train.py

# -*- coding: utf-8 -*-
'''
@Time          : 2020/05/06 15:07
@Author        : Tianxiaomo
@File          : train.py
@Notice        :
@Modification  :
    @Author    :
    @Time      :
    @Detail    :
'''
import time
import logging
import os, sys, math
import argparse
from collections import deque
import datetime

import cv2
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch import optim
from torch.nn import functional as F
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict

from dataset import Yolo_dataset
from cfg import Cfg
from models import Yolov4
from tool.darknet2pytorch import Darknet

from tool.tv_reference.utils import collate_fn as val_collate
from tool.tv_reference.coco_utils import convert_to_coco_api
from tool.tv_reference.coco_eval import CocoEvaluator


def bboxes_iou(bboxes_a, bboxes_b, xyxy=True, GIoU=False, DIoU=False, CIoU=False):
    """Calculate the Intersection of Unions (IoUs) between bounding boxes.

    IoU is calculated as a ratio of area of the intersection
    and area of the union.

    Args:
        bboxes_a (array): An array whose shape is :math:`(N, 4)`.
            :math:`N` is the number of bounding boxes.
            The dtype should be :obj:`numpy.float32`.
        bboxes_b (array): An array similar to :obj:`bboxes_a`,
            whose shape is :math:`(K, 4)`.
            The dtype should be :obj:`numpy.float32`.
    Returns:
        array:
            An array whose shape is :math:`(N, K)`. \
            An element at index :math:`(n, k)` contains IoUs between \
            :math:`n` th bounding box in :obj:`bboxes_a` and :math:`k` th bounding \
            box in :obj:`bboxes_b`.

    from: https://github.com/chainer/chainercv
    https://github.com/ultralytics/yolov3/blob/eca5b9c1d36e4f73bf2f94e141d864f1c2739e23/utils/utils.py#L262-L282
    """
    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
        raise IndexError

    if xyxy:
        # intersection top left
        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
        # intersection bottom right
        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
        # convex (smallest enclosing box) top left and bottom right
        con_tl = torch.min(bboxes_a[:, None, :2], bboxes_b[:, :2])
        con_br = torch.max(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
        # centerpoint distance squared
        rho2 = ((bboxes_a[:, None, 0] + bboxes_a[:, None, 2]) - (bboxes_b[:, 0] + bboxes_b[:, 2])) ** 2 / 4 + (
                (bboxes_a[:, None, 1] + bboxes_a[:, None, 3]) - (bboxes_b[:, 1] + bboxes_b[:, 3])) ** 2 / 4

        w1 = bboxes_a[:, 2] - bboxes_a[:, 0]
        h1 = bboxes_a[:, 3] - bboxes_a[:, 1]
        w2 = bboxes_b[:, 2] - bboxes_b[:, 0]
        h2 = bboxes_b[:, 3] - bboxes_b[:, 1]

        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
    else:
        # intersection top left
        tl = torch.max((bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
                       (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2))
        # intersection bottom right
        br = torch.min((bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
                       (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2))
        # convex (smallest enclosing box) top left and bottom right
        con_tl = torch.min((bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
                           (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2))
        con_br = torch.max((bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
                           (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2))
        # centerpoint distance squared
        rho2 = ((bboxes_a[:, None, :2] - bboxes_b[:, :2]) ** 2 / 4).sum(dim=-1)

        w1 = bboxes_a[:, 2]
        h1 = bboxes_a[:, 3]
        w2 = bboxes_b[:, 2]
        h2 = bboxes_b[:, 3]

        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)
    en = (tl < br).type(tl.type()).prod(dim=2)
    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
    area_u = area_a[:, None] + area_b - area_i
    iou = area_i / area_u

    if GIoU or DIoU or CIoU:
        if GIoU:  # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
            area_c = torch.prod(con_br - con_tl, 2)  # convex area
            return iou - (area_c - area_u) / area_c  # GIoU
        if DIoU or CIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            # convex diagonal squared
            c2 = torch.pow(con_br - con_tl, 2).sum(dim=2) + 1e-16
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w1 / h1).unsqueeze(1) - torch.atan(w2 / h2), 2)
                with torch.no_grad():
                    alpha = v / (1 - iou + v)
                return iou - (rho2 / c2 + v * alpha)  # CIoU
    return iou
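

# Illustrative sketch (not part of the original training code): a minimal example
# of how bboxes_iou can be called on two small sets of corner-format boxes. The
# function name and the box values below are made up for demonstration only.
def _bboxes_iou_example():
    boxes_a = torch.tensor([[0., 0., 10., 10.],
                            [5., 5., 15., 15.]])      # (N, 4) in x1, y1, x2, y2
    boxes_b = torch.tensor([[2., 2., 12., 8.],
                            [20., 20., 30., 30.]])    # (K, 4) in x1, y1, x2, y2
    plain_iou = bboxes_iou(boxes_a, boxes_b, xyxy=True)          # (N, K) pairwise IoU
    ciou = bboxes_iou(boxes_a, boxes_b, xyxy=True, CIoU=True)    # (N, K) pairwise CIoU
    return plain_iou, ciou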


class Yolo_loss(nn.Module):
    def __init__(self, n_classes=80, n_anchors=3, device=None, batch=2):
        super(Yolo_loss, self).__init__()
        self.device = device
        self.strides = [8, 16, 32]
        image_size = 608
        self.n_classes = n_classes
        self.n_anchors = n_anchors

        self.anchors = [[12, 16], [19, 36], [40, 28], [36, 75], [76, 55], [72, 146], [142, 110], [192, 243], [459, 401]]
        self.anch_masks = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        self.ignore_thre = 0.5

        self.masked_anchors, self.ref_anchors, self.grid_x, self.grid_y, self.anchor_w, self.anchor_h = [], [], [], [], [], []

        for i in range(3):
            all_anchors_grid = [(w / self.strides[i], h / self.strides[i]) for w, h in self.anchors]
            masked_anchors = np.array([all_anchors_grid[j] for j in self.anch_masks[i]], dtype=np.float32)
            ref_anchors = np.zeros((len(all_anchors_grid), 4), dtype=np.float32)
            ref_anchors[:, 2:] = np.array(all_anchors_grid, dtype=np.float32)
            ref_anchors = torch.from_numpy(ref_anchors)
            # calculate pred - xywh obj cls
            fsize = image_size // self.strides[i]
            grid_x = torch.arange(fsize, dtype=torch.float).repeat(batch, 3, fsize, 1).to(device)
            grid_y = torch.arange(fsize, dtype=torch.float).repeat(batch, 3, fsize, 1).permute(0, 1, 3, 2).to(device)
            anchor_w = torch.from_numpy(masked_anchors[:, 0]).repeat(batch, fsize, fsize, 1).permute(0, 3, 1, 2).to(
                device)
            anchor_h = torch.from_numpy(masked_anchors[:, 1]).repeat(batch, fsize, fsize, 1).permute(0, 3, 1, 2).to(
                device)

            self.masked_anchors.append(masked_anchors)
            self.ref_anchors.append(ref_anchors)
            self.grid_x.append(grid_x)
            self.grid_y.append(grid_y)
            self.anchor_w.append(anchor_w)
            self.anchor_h.append(anchor_h)
    def build_target(self, pred, labels, batchsize, fsize, n_ch, output_id):
        # target assignment
        tgt_mask = torch.zeros(batchsize, self.n_anchors, fsize, fsize, 4 + self.n_classes).to(device=self.device)
        obj_mask = torch.ones(batchsize, self.n_anchors, fsize, fsize).to(device=self.device)
        tgt_scale = torch.zeros(batchsize, self.n_anchors, fsize, fsize, 2).to(self.device)
        target = torch.zeros(batchsize, self.n_anchors, fsize, fsize, n_ch).to(self.device)

        # labels = labels.cpu().data
        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects

        truth_x_all = (labels[:, :, 2] + labels[:, :, 0]) / (self.strides[output_id] * 2)
        truth_y_all = (labels[:, :, 3] + labels[:, :, 1]) / (self.strides[output_id] * 2)
        truth_w_all = (labels[:, :, 2] - labels[:, :, 0]) / self.strides[output_id]
        truth_h_all = (labels[:, :, 3] - labels[:, :, 1]) / self.strides[output_id]
        truth_i_all = truth_x_all.to(torch.int16).cpu().numpy()
        truth_j_all = truth_y_all.to(torch.int16).cpu().numpy()

        for b in range(batchsize):
            n = int(nlabel[b])
            if n == 0:
                continue
            truth_box = torch.zeros(n, 4).to(self.device)
            truth_box[:n, 2] = truth_w_all[b, :n]
            truth_box[:n, 3] = truth_h_all[b, :n]
            truth_i = truth_i_all[b, :n]
            truth_j = truth_j_all[b, :n]

            # calculate iou between truth and reference anchors
            anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors[output_id], CIoU=True)

            # temp = bbox_iou(truth_box.cpu(), self.ref_anchors[output_id])

            best_n_all = anchor_ious_all.argmax(dim=1)
            best_n = best_n_all % 3
            best_n_mask = ((best_n_all == self.anch_masks[output_id][0]) |
                           (best_n_all == self.anch_masks[output_id][1]) |
                           (best_n_all == self.anch_masks[output_id][2]))
            if sum(best_n_mask) == 0:
                continue

            truth_box[:n, 0] = truth_x_all[b, :n]
            truth_box[:n, 1] = truth_y_all[b, :n]

            pred_ious = bboxes_iou(pred[b].view(-1, 4), truth_box, xyxy=False)
            pred_best_iou, _ = pred_ious.max(dim=1)
            pred_best_iou = (pred_best_iou > self.ignore_thre)
            pred_best_iou = pred_best_iou.view(pred[b].shape[:3])
            # set mask to zero (ignore) if pred matches truth
            obj_mask[b] = ~ pred_best_iou

            for ti in range(best_n.shape[0]):
                if best_n_mask[ti] == 1:
                    i, j = truth_i[ti], truth_j[ti]
                    a = best_n[ti]
                    obj_mask[b, a, j, i] = 1
                    tgt_mask[b, a, j, i, :] = 1
                    target[b, a, j, i, 0] = truth_x_all[b, ti] - truth_x_all[b, ti].to(torch.int16).to(torch.float)
                    target[b, a, j, i, 1] = truth_y_all[b, ti] - truth_y_all[b, ti].to(torch.int16).to(torch.float)
                    target[b, a, j, i, 2] = torch.log(
                        truth_w_all[b, ti] / torch.Tensor(self.masked_anchors[output_id])[best_n[ti], 0] + 1e-16)
                    target[b, a, j, i, 3] = torch.log(
                        truth_h_all[b, ti] / torch.Tensor(self.masked_anchors[output_id])[best_n[ti], 1] + 1e-16)
                    target[b, a, j, i, 4] = 1
                    target[b, a, j, i, 5 + labels[b, ti, 4].to(torch.int16).cpu().numpy()] = 1
                    tgt_scale[b, a, j, i, :] = torch.sqrt(2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize)
        return obj_mask, tgt_mask, tgt_scale, target
    def forward(self, xin, labels=None):
        loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2 = 0, 0, 0, 0, 0, 0
        for output_id, output in enumerate(xin):
            batchsize = output.shape[0]
            fsize = output.shape[2]
            n_ch = 5 + self.n_classes

            output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize)
            output = output.permute(0, 1, 3, 4, 2)  # .contiguous()

            # logistic activation for xy, obj, cls
            output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(output[..., np.r_[:2, 4:n_ch]])

            pred = output[..., :4].clone()
            pred[..., 0] += self.grid_x[output_id]
            pred[..., 1] += self.grid_y[output_id]
            pred[..., 2] = torch.exp(pred[..., 2]) * self.anchor_w[output_id]
            pred[..., 3] = torch.exp(pred[..., 3]) * self.anchor_h[output_id]

            obj_mask, tgt_mask, tgt_scale, target = self.build_target(pred, labels, batchsize, fsize, n_ch, output_id)

            # loss calculation
            output[..., 4] *= obj_mask
            output[..., np.r_[0:4, 5:n_ch]] *= tgt_mask
            output[..., 2:4] *= tgt_scale

            target[..., 4] *= obj_mask
            target[..., np.r_[0:4, 5:n_ch]] *= tgt_mask
            target[..., 2:4] *= tgt_scale

            loss_xy += F.binary_cross_entropy(input=output[..., :2], target=target[..., :2],
                                              weight=tgt_scale * tgt_scale, reduction='sum')
            loss_wh += F.mse_loss(input=output[..., 2:4], target=target[..., 2:4], reduction='sum') / 2
            loss_obj += F.binary_cross_entropy(input=output[..., 4], target=target[..., 4], reduction='sum')
            loss_cls += F.binary_cross_entropy(input=output[..., 5:], target=target[..., 5:], reduction='sum')
            loss_l2 += F.mse_loss(input=output, target=target, reduction='sum')

        loss = loss_xy + loss_wh + loss_obj + loss_cls

        return loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2
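

# Illustrative sketch (not part of the original file): how Yolo_loss expects its
# inputs to be shaped. The three "raw" head outputs are random tensors and the
# single label box is made up, so this only demonstrates shapes and the call
# signature, not meaningful loss values. The function name is hypothetical.
def _yolo_loss_example():
    batch = 2
    criterion = Yolo_loss(n_classes=80, n_anchors=3, device=torch.device('cpu'), batch=batch)
    # One raw output per stride (8, 16, 32) for a 608x608 input:
    # (batch, n_anchors * (5 + n_classes), fsize, fsize) with fsize = 608 // stride.
    xin = [torch.randn(batch, 3 * 85, 608 // s, 608 // s) for s in (8, 16, 32)]
    # Labels are zero-padded (batch, max_boxes, 5) rows of [x1, y1, x2, y2, class]
    # in input-image pixel coordinates; all-zero rows mean "no object".
    labels = torch.zeros(batch, 60, 5)
    labels[:, 0, :] = torch.tensor([100.0, 150.0, 300.0, 400.0, 15.0])
    loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2 = criterion(xin, labels)
    return loss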


def collate(batch):
    images = []
    bboxes = []
    for img, box in batch:
        images.append([img])
        bboxes.append([box])
    images = np.concatenate(images, axis=0)
    images = images.transpose(0, 3, 1, 2)
    images = torch.from_numpy(images).div(255.0)
    bboxes = np.concatenate(bboxes, axis=0)
    bboxes = torch.from_numpy(bboxes)
    return images, bboxes
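

# Illustrative sketch (not part of the original file): what collate does to
# synthetic (image, boxes) pairs of the kind Yolo_dataset yields -- HWC images
# become a normalised NCHW float batch, and the per-image padded box arrays are
# stacked into one tensor. Shapes and values are made up; the name is hypothetical.
def _collate_example():
    img = np.random.randint(0, 256, size=(608, 608, 3)).astype(np.float32)  # HWC image
    boxes = np.zeros((60, 5), dtype=np.float32)                             # padded [x1, y1, x2, y2, cls]
    boxes[0] = [100, 150, 300, 400, 15]
    images, bboxes = collate([(img, boxes), (img, boxes)])
    # images: torch.Size([2, 3, 608, 608]); bboxes: torch.Size([2, 60, 5])
    return images, bboxes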


def train(model, device, config, epochs=5, batch_size=1, save_cp=True, log_step=20, img_scale=0.5):
    train_dataset = Yolo_dataset(config.train_label, config, train=True)
    val_dataset = Yolo_dataset(config.val_label, config, train=False)

    n_train = len(train_dataset)
    n_val = len(val_dataset)

    train_loader = DataLoader(train_dataset, batch_size=config.batch // config.subdivisions, shuffle=True,
                              num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate)

    val_loader = DataLoader(val_dataset, batch_size=config.batch // config.subdivisions, shuffle=True, num_workers=8,
                            pin_memory=True, drop_last=True, collate_fn=val_collate)

    writer = SummaryWriter(log_dir=config.TRAIN_TENSORBOARD_DIR,
                           filename_suffix=f'OPT_{config.TRAIN_OPTIMIZER}_LR_{config.learning_rate}_BS_{config.batch}_Sub_{config.subdivisions}_Size_{config.width}',
                           comment=f'OPT_{config.TRAIN_OPTIMIZER}_LR_{config.learning_rate}_BS_{config.batch}_Sub_{config.subdivisions}_Size_{config.width}')
    # writer.add_images('legend',
    #                   torch.from_numpy(train_dataset.label2colorlegend2(cfg.DATA_CLASSES).transpose([2, 0, 1])).to(
    #                       device).unsqueeze(0))
    max_itr = config.TRAIN_EPOCHS * n_train
    # global_step = cfg.TRAIN_MINEPOCH * n_train
    global_step = 0
    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {config.batch}
        Subdivisions:    {config.subdivisions}
        Learning rate:   {config.learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images size:     {config.width}
        Optimizer:       {config.TRAIN_OPTIMIZER}
        Dataset classes: {config.classes}
        Train label path:{config.train_label}
        Pretrained:
    ''')

    # learning rate setup
    def burnin_schedule(i):
        if i < config.burn_in:
            factor = pow(i / config.burn_in, 4)
        elif i < config.steps[0]:
            factor = 1.0
        elif i < config.steps[1]:
            factor = 0.1
        else:
            factor = 0.01
        return factor

    if config.TRAIN_OPTIMIZER.lower() == 'adam':
        optimizer = optim.Adam(
            model.parameters(),
            lr=config.learning_rate / config.batch,
            betas=(0.9, 0.999),
            eps=1e-08,
        )
    elif config.TRAIN_OPTIMIZER.lower() == 'sgd':
        optimizer = optim.SGD(
            params=model.parameters(),
            lr=config.learning_rate / config.batch,
            momentum=config.momentum,
            weight_decay=config.decay,
        )
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)

    criterion = Yolo_loss(device=device, batch=config.batch // config.subdivisions, n_classes=config.classes)
    # scheduler = ReduceLROnPlateau(optimizer, mode='max', verbose=True, patience=6, min_lr=1e-7)
    # scheduler = CosineAnnealingWarmRestarts(optimizer, 0.001, 1e-6, 20)

    save_prefix = 'Yolov4_epoch'
    saved_models = deque()
    model.train()
    for epoch in range(epochs):
        # model.train()
        epoch_loss = 0
        epoch_step = 0

        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img', ncols=50) as pbar:
            for i, batch in enumerate(train_loader):
                global_step += 1
                epoch_step += 1
                images = batch[0]
                bboxes = batch[1]

                images = images.to(device=device, dtype=torch.float32)
                bboxes = bboxes.to(device=device)

                bboxes_pred = model(images)
                loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2 = criterion(bboxes_pred, bboxes)
                # loss = loss / config.subdivisions
                loss.backward()

                epoch_loss += loss.item()

                if global_step % config.subdivisions == 0:
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()

                if global_step % (log_step * config.subdivisions) == 0:
                    writer.add_scalar('train/Loss', loss.item(), global_step)
                    writer.add_scalar('train/loss_xy', loss_xy.item(), global_step)
                    writer.add_scalar('train/loss_wh', loss_wh.item(), global_step)
                    writer.add_scalar('train/loss_obj', loss_obj.item(), global_step)
                    writer.add_scalar('train/loss_cls', loss_cls.item(), global_step)
                    writer.add_scalar('train/loss_l2', loss_l2.item(), global_step)
                    writer.add_scalar('lr', scheduler.get_lr()[0] * config.batch, global_step)
                    pbar.set_postfix(**{'loss (batch)': loss.item(), 'loss_xy': loss_xy.item(),
                                        'loss_wh': loss_wh.item(),
                                        'loss_obj': loss_obj.item(),
                                        'loss_cls': loss_cls.item(),
                                        'loss_l2': loss_l2.item(),
                                        'lr': scheduler.get_lr()[0] * config.batch
                                        })
                    logging.debug('Train step_{}: loss : {},loss xy : {},loss wh : {},'
                                  'loss obj : {},loss cls : {},loss l2 : {},lr : {}'
                                  .format(global_step, loss.item(), loss_xy.item(),
                                          loss_wh.item(), loss_obj.item(),
                                          loss_cls.item(), loss_l2.item(),
                                          scheduler.get_lr()[0] * config.batch))

                pbar.update(images.shape[0])

            if cfg.use_darknet_cfg:
                eval_model = Darknet(cfg.cfgfile, inference=True)
            else:
                eval_model = Yolov4(cfg.pretrained, n_classes=cfg.classes, inference=True)
            # eval_model = Yolov4(yolov4conv137weight=None, n_classes=config.classes, inference=True)
            if torch.cuda.device_count() > 1:
                eval_model.load_state_dict(model.module.state_dict())
            else:
                eval_model.load_state_dict(model.state_dict())
            eval_model.to(device)
            evaluator = evaluate(eval_model, val_loader, config, device)
            del eval_model

            stats = evaluator.coco_eval['bbox'].stats
            writer.add_scalar('train/AP', stats[0], global_step)
            writer.add_scalar('train/AP50', stats[1], global_step)
            writer.add_scalar('train/AP75', stats[2], global_step)
            writer.add_scalar('train/AP_small', stats[3], global_step)
            writer.add_scalar('train/AP_medium', stats[4], global_step)
            writer.add_scalar('train/AP_large', stats[5], global_step)
            writer.add_scalar('train/AR1', stats[6], global_step)
            writer.add_scalar('train/AR10', stats[7], global_step)
            writer.add_scalar('train/AR100', stats[8], global_step)
            writer.add_scalar('train/AR_small', stats[9], global_step)
            writer.add_scalar('train/AR_medium', stats[10], global_step)
            writer.add_scalar('train/AR_large', stats[11], global_step)

            if save_cp:
                try:
                    # os.mkdir(config.checkpoints)
                    os.makedirs(config.checkpoints, exist_ok=True)
                    logging.info('Created checkpoint directory')
                except OSError:
                    pass
                save_path = os.path.join(config.checkpoints, f'{save_prefix}{epoch + 1}.pth')
                torch.save(model.state_dict(), save_path)
                logging.info(f'Checkpoint {epoch + 1} saved !')
                saved_models.append(save_path)
                if len(saved_models) > config.keep_checkpoint_max > 0:
                    model_to_remove = saved_models.popleft()
                    try:
                        os.remove(model_to_remove)
                    except OSError:
                        logging.info(f'failed to remove {model_to_remove}')

    writer.close()


@torch.no_grad()
def evaluate(model, data_loader, cfg, device, logger=None, **kwargs):
    """ finished, tested
    """
    # cpu_device = torch.device("cpu")
    model.eval()
    # header = 'Test:'

    coco = convert_to_coco_api(data_loader.dataset, bbox_fmt='coco')
    coco_evaluator = CocoEvaluator(coco, iou_types=["bbox"], bbox_fmt='coco')

    for images, targets in data_loader:
        model_input = [[cv2.resize(img, (cfg.w, cfg.h))] for img in images]
        model_input = np.concatenate(model_input, axis=0)
        model_input = model_input.transpose(0, 3, 1, 2)
        model_input = torch.from_numpy(model_input).div(255.0)
        model_input = model_input.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(model_input)

        # outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        # outputs = outputs.cpu().detach().numpy()
        res = {}
        # for img, target, output in zip(images, targets, outputs):
        for img, target, boxes, confs in zip(images, targets, outputs[0], outputs[1]):
            img_height, img_width = img.shape[:2]
            # boxes = output[...,:4].copy()  # output boxes in yolo format
            boxes = boxes.squeeze(2).cpu().detach().numpy()
            boxes[..., 2:] = boxes[..., 2:] - boxes[..., :2]  # Transform [x1, y1, x2, y2] to [x1, y1, w, h]
            boxes[..., 0] = boxes[..., 0] * img_width
            boxes[..., 1] = boxes[..., 1] * img_height
            boxes[..., 2] = boxes[..., 2] * img_width
            boxes[..., 3] = boxes[..., 3] * img_height
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # confs = output[...,4:].copy()
            confs = confs.cpu().detach().numpy()
            labels = np.argmax(confs, axis=1).flatten()
            labels = torch.as_tensor(labels, dtype=torch.int64)
            scores = np.max(confs, axis=1).flatten()
            scores = torch.as_tensor(scores, dtype=torch.float32)
            res[target["image_id"].item()] = {
                "boxes": boxes,
                "scores": scores,
                "labels": labels,
            }
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time

    # gather the stats from all processes
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()

    return coco_evaluator


def get_args(**kwargs):
    cfg = kwargs
    parser = argparse.ArgumentParser(description='Train the Model on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=2,
    #                     help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.001,
                        help='Learning rate', dest='learning_rate')
    parser.add_argument('-f', '--load', dest='load', type=str, default=None,
                        help='Load model from a .pth file')
    parser.add_argument('-g', '--gpu', metavar='G', type=str, default='-1',
                        help='GPU', dest='gpu')
    parser.add_argument('-dir', '--data-dir', type=str, default=None,
                        help='dataset dir', dest='dataset_dir')
    parser.add_argument('-pretrained', type=str, default=None, help='pretrained yolov4.conv.137')
    parser.add_argument('-classes', type=int, default=80, help='dataset classes')
    parser.add_argument('-train_label_path', dest='train_label', type=str, default='train.txt',
                        help="train label path")
    parser.add_argument(
        '-optimizer', type=str, default='adam',
        help='training optimizer',
        dest='TRAIN_OPTIMIZER')
    parser.add_argument(
        '-iou-type', type=str, default='iou',
        help='iou type (iou, giou, diou, ciou)',
        dest='iou_type')
    parser.add_argument(
        '-keep-checkpoint-max', type=int, default=10,
        help='maximum number of checkpoints to keep. If set 0, all checkpoints will be kept',
        dest='keep_checkpoint_max')
    args = vars(parser.parse_args())

    # for k in args.keys():
    #     cfg[k] = args.get(k)
    cfg.update(args)

    return edict(cfg)


def init_logger(log_file=None, log_dir=None, log_level=logging.INFO, mode='w', stdout=True):
    """
    log_dir: directory that holds the log file
    mode: 'a', append; 'w', overwrite the existing file.
    """
    def get_date_str():
        now = datetime.datetime.now()
        return now.strftime('%Y-%m-%d_%H-%M-%S')

    fmt = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s'
    if log_dir is None:
        log_dir = os.path.expanduser('~/temp/log/')
    if log_file is None:
        log_file = 'log_' + get_date_str() + '.txt'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    log_file = os.path.join(log_dir, log_file)
    # logging cannot be used here yet; it is configured just below
    print('log file path:' + log_file)

    logging.basicConfig(level=logging.DEBUG,
                        format=fmt,
                        filename=log_file,
                        filemode=mode)

    if stdout:
        console = logging.StreamHandler(stream=sys.stdout)
        console.setLevel(log_level)
        formatter = logging.Formatter(fmt)
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

    return logging


def _get_date_str():
    now = datetime.datetime.now()
    return now.strftime('%Y-%m-%d_%H-%M')


if __name__ == "__main__":
    logging = init_logger(log_dir='log')
    cfg = get_args(**Cfg)
    os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    if cfg.use_darknet_cfg:
        model = Darknet(cfg.cfgfile)
    else:
        model = Yolov4(cfg.pretrained, n_classes=cfg.classes)

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(device=device)

    try:
        train(model=model,
              config=cfg,
              epochs=cfg.TRAIN_EPOCHS,
              device=device, )
    except KeyboardInterrupt:
        torch.save(model.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)