# models.py
import sys

import torch
import torch.nn.functional as F
from torch import nn

from tool.torch_utils import *
from tool.yolo_layer import YoloLayer
  6. class Mish(torch.nn.Module):
  7. def __init__(self):
  8. super().__init__()
  9. def forward(self, x):
  10. x = x * (torch.tanh(torch.nn.functional.softplus(x)))
  11. return x
  12. class Upsample(nn.Module):
  13. def __init__(self):
  14. super(Upsample, self).__init__()
  15. def forward(self, x, target_size, inference=False):
  16. assert (x.data.dim() == 4)
  17. # _, _, tH, tW = target_size
  18. if inference:
  19. #B = x.data.size(0)
  20. #C = x.data.size(1)
  21. #H = x.data.size(2)
  22. #W = x.data.size(3)
  23. return x.view(x.size(0), x.size(1), x.size(2), 1, x.size(3), 1).\
  24. expand(x.size(0), x.size(1), x.size(2), target_size[2] // x.size(2), x.size(3), target_size[3] // x.size(3)).\
  25. contiguous().view(x.size(0), x.size(1), target_size[2], target_size[3])
  26. else:
  27. return F.interpolate(x, size=(target_size[2], target_size[3]), mode='nearest')
  28. class Conv_Bn_Activation(nn.Module):
  29. def __init__(self, in_channels, out_channels, kernel_size, stride, activation, bn=True, bias=False):
  30. super().__init__()
  31. pad = (kernel_size - 1) // 2
  32. self.conv = nn.ModuleList()
  33. if bias:
  34. self.conv.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad))
  35. else:
  36. self.conv.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad, bias=False))
  37. if bn:
  38. self.conv.append(nn.BatchNorm2d(out_channels))
  39. if activation == "mish":
  40. self.conv.append(Mish())
  41. elif activation == "relu":
  42. self.conv.append(nn.ReLU(inplace=True))
  43. elif activation == "leaky":
  44. self.conv.append(nn.LeakyReLU(0.1, inplace=True))
  45. elif activation == "linear":
  46. pass
  47. else:
  48. print("activate error !!! {} {} {}".format(sys._getframe().f_code.co_filename,
  49. sys._getframe().f_code.co_name, sys._getframe().f_lineno))
  50. def forward(self, x):
  51. for l in self.conv:
  52. x = l(x)
  53. return x
  54. class ResBlock(nn.Module):
  55. """
  56. Sequential residual blocks each of which consists of \
  57. two convolution layers.
  58. Args:
  59. ch (int): number of input and output channels.
  60. nblocks (int): number of residual blocks.
  61. shortcut (bool): if True, residual tensor addition is enabled.
  62. """
  63. def __init__(self, ch, nblocks=1, shortcut=True):
  64. super().__init__()
  65. self.shortcut = shortcut
  66. self.module_list = nn.ModuleList()
  67. for i in range(nblocks):
  68. resblock_one = nn.ModuleList()
  69. resblock_one.append(Conv_Bn_Activation(ch, ch, 1, 1, 'mish'))
  70. resblock_one.append(Conv_Bn_Activation(ch, ch, 3, 1, 'mish'))
  71. self.module_list.append(resblock_one)
  72. def forward(self, x):
  73. for module in self.module_list:
  74. h = x
  75. for res in module:
  76. h = res(h)
  77. x = x + h if self.shortcut else h
  78. return x
  79. class DownSample1(nn.Module):
  80. def __init__(self):
  81. super().__init__()
  82. self.conv1 = Conv_Bn_Activation(3, 32, 3, 1, 'mish')
  83. self.conv2 = Conv_Bn_Activation(32, 64, 3, 2, 'mish')
  84. self.conv3 = Conv_Bn_Activation(64, 64, 1, 1, 'mish')
  85. # [route]
  86. # layers = -2
  87. self.conv4 = Conv_Bn_Activation(64, 64, 1, 1, 'mish')
  88. self.conv5 = Conv_Bn_Activation(64, 32, 1, 1, 'mish')
  89. self.conv6 = Conv_Bn_Activation(32, 64, 3, 1, 'mish')
  90. # [shortcut]
  91. # from=-3
  92. # activation = linear
  93. self.conv7 = Conv_Bn_Activation(64, 64, 1, 1, 'mish')
  94. # [route]
  95. # layers = -1, -7
  96. self.conv8 = Conv_Bn_Activation(128, 64, 1, 1, 'mish')
  97. def forward(self, input):
  98. x1 = self.conv1(input)
  99. x2 = self.conv2(x1)
  100. x3 = self.conv3(x2)
  101. # route -2
  102. x4 = self.conv4(x2)
  103. x5 = self.conv5(x4)
  104. x6 = self.conv6(x5)
  105. # shortcut -3
  106. x6 = x6 + x4
  107. x7 = self.conv7(x6)
  108. # [route]
  109. # layers = -1, -7
  110. x7 = torch.cat([x7, x3], dim=1)
  111. x8 = self.conv8(x7)
  112. return x8
  113. class DownSample2(nn.Module):
  114. def __init__(self):
  115. super().__init__()
  116. self.conv1 = Conv_Bn_Activation(64, 128, 3, 2, 'mish')
  117. self.conv2 = Conv_Bn_Activation(128, 64, 1, 1, 'mish')
  118. # r -2
  119. self.conv3 = Conv_Bn_Activation(128, 64, 1, 1, 'mish')
  120. self.resblock = ResBlock(ch=64, nblocks=2)
  121. # s -3
  122. self.conv4 = Conv_Bn_Activation(64, 64, 1, 1, 'mish')
  123. # r -1 -10
  124. self.conv5 = Conv_Bn_Activation(128, 128, 1, 1, 'mish')
  125. def forward(self, input):
  126. x1 = self.conv1(input)
  127. x2 = self.conv2(x1)
  128. x3 = self.conv3(x1)
  129. r = self.resblock(x3)
  130. x4 = self.conv4(r)
  131. x4 = torch.cat([x4, x2], dim=1)
  132. x5 = self.conv5(x4)
  133. return x5
  134. class DownSample3(nn.Module):
  135. def __init__(self):
  136. super().__init__()
  137. self.conv1 = Conv_Bn_Activation(128, 256, 3, 2, 'mish')
  138. self.conv2 = Conv_Bn_Activation(256, 128, 1, 1, 'mish')
  139. self.conv3 = Conv_Bn_Activation(256, 128, 1, 1, 'mish')
  140. self.resblock = ResBlock(ch=128, nblocks=8)
  141. self.conv4 = Conv_Bn_Activation(128, 128, 1, 1, 'mish')
  142. self.conv5 = Conv_Bn_Activation(256, 256, 1, 1, 'mish')
  143. def forward(self, input):
  144. x1 = self.conv1(input)
  145. x2 = self.conv2(x1)
  146. x3 = self.conv3(x1)
  147. r = self.resblock(x3)
  148. x4 = self.conv4(r)
  149. x4 = torch.cat([x4, x2], dim=1)
  150. x5 = self.conv5(x4)
  151. return x5
  152. class DownSample4(nn.Module):
  153. def __init__(self):
  154. super().__init__()
  155. self.conv1 = Conv_Bn_Activation(256, 512, 3, 2, 'mish')
  156. self.conv2 = Conv_Bn_Activation(512, 256, 1, 1, 'mish')
  157. self.conv3 = Conv_Bn_Activation(512, 256, 1, 1, 'mish')
  158. self.resblock = ResBlock(ch=256, nblocks=8)
  159. self.conv4 = Conv_Bn_Activation(256, 256, 1, 1, 'mish')
  160. self.conv5 = Conv_Bn_Activation(512, 512, 1, 1, 'mish')
  161. def forward(self, input):
  162. x1 = self.conv1(input)
  163. x2 = self.conv2(x1)
  164. x3 = self.conv3(x1)
  165. r = self.resblock(x3)
  166. x4 = self.conv4(r)
  167. x4 = torch.cat([x4, x2], dim=1)
  168. x5 = self.conv5(x4)
  169. return x5
  170. class DownSample5(nn.Module):
  171. def __init__(self):
  172. super().__init__()
  173. self.conv1 = Conv_Bn_Activation(512, 1024, 3, 2, 'mish')
  174. self.conv2 = Conv_Bn_Activation(1024, 512, 1, 1, 'mish')
  175. self.conv3 = Conv_Bn_Activation(1024, 512, 1, 1, 'mish')
  176. self.resblock = ResBlock(ch=512, nblocks=4)
  177. self.conv4 = Conv_Bn_Activation(512, 512, 1, 1, 'mish')
  178. self.conv5 = Conv_Bn_Activation(1024, 1024, 1, 1, 'mish')
  179. def forward(self, input):
  180. x1 = self.conv1(input)
  181. x2 = self.conv2(x1)
  182. x3 = self.conv3(x1)
  183. r = self.resblock(x3)
  184. x4 = self.conv4(r)
  185. x4 = torch.cat([x4, x2], dim=1)
  186. x5 = self.conv5(x4)
  187. return x5
  188. class Neck(nn.Module):
  189. def __init__(self, inference=False):
  190. super().__init__()
  191. self.inference = inference
  192. self.conv1 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')
  193. self.conv2 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')
  194. self.conv3 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')
  195. # SPP
  196. self.maxpool1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=5 // 2)
  197. self.maxpool2 = nn.MaxPool2d(kernel_size=9, stride=1, padding=9 // 2)
  198. self.maxpool3 = nn.MaxPool2d(kernel_size=13, stride=1, padding=13 // 2)
  199. # R -1 -3 -5 -6
  200. # SPP
  201. self.conv4 = Conv_Bn_Activation(2048, 512, 1, 1, 'leaky')
  202. self.conv5 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')
  203. self.conv6 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')
  204. self.conv7 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')
  205. # UP
  206. self.upsample1 = Upsample()
  207. # R 85
  208. self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')
  209. # R -1 -3
  210. self.conv9 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')
  211. self.conv10 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')
  212. self.conv11 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')
  213. self.conv12 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')
  214. self.conv13 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')
  215. self.conv14 = Conv_Bn_Activation(256, 128, 1, 1, 'leaky')
  216. # UP
  217. self.upsample2 = Upsample()
  218. # R 54
  219. self.conv15 = Conv_Bn_Activation(256, 128, 1, 1, 'leaky')
  220. # R -1 -3
  221. self.conv16 = Conv_Bn_Activation(256, 128, 1, 1, 'leaky')
  222. self.conv17 = Conv_Bn_Activation(128, 256, 3, 1, 'leaky')
  223. self.conv18 = Conv_Bn_Activation(256, 128, 1, 1, 'leaky')
  224. self.conv19 = Conv_Bn_Activation(128, 256, 3, 1, 'leaky')
  225. self.conv20 = Conv_Bn_Activation(256, 128, 1, 1, 'leaky')
  226. def forward(self, input, downsample4, downsample3, inference=False):
  227. x1 = self.conv1(input)
  228. x2 = self.conv2(x1)
  229. x3 = self.conv3(x2)
  230. # SPP
  231. m1 = self.maxpool1(x3)
  232. m2 = self.maxpool2(x3)
  233. m3 = self.maxpool3(x3)
  234. spp = torch.cat([m3, m2, m1, x3], dim=1)
  235. # SPP end
  236. x4 = self.conv4(spp)
  237. x5 = self.conv5(x4)
  238. x6 = self.conv6(x5)
  239. x7 = self.conv7(x6)
  240. # UP
  241. up = self.upsample1(x7, downsample4.size(), self.inference)
  242. # R 85
  243. x8 = self.conv8(downsample4)
  244. # R -1 -3
  245. x8 = torch.cat([x8, up], dim=1)
  246. x9 = self.conv9(x8)
  247. x10 = self.conv10(x9)
  248. x11 = self.conv11(x10)
  249. x12 = self.conv12(x11)
  250. x13 = self.conv13(x12)
  251. x14 = self.conv14(x13)
  252. # UP
  253. up = self.upsample2(x14, downsample3.size(), self.inference)
  254. # R 54
  255. x15 = self.conv15(downsample3)
  256. # R -1 -3
  257. x15 = torch.cat([x15, up], dim=1)
  258. x16 = self.conv16(x15)
  259. x17 = self.conv17(x16)
  260. x18 = self.conv18(x17)
  261. x19 = self.conv19(x18)
  262. x20 = self.conv20(x19)
  263. return x20, x13, x6
  264. class Yolov4Head(nn.Module):
  265. def __init__(self, output_ch, n_classes, inference=False):
  266. super().__init__()
  267. self.inference = inference
  268. self.conv1 = Conv_Bn_Activation(128, 256, 3, 1, 'leaky')
  269. self.conv2 = Conv_Bn_Activation(256, output_ch, 1, 1, 'linear', bn=False, bias=True)
  270. self.yolo1 = YoloLayer(
  271. anchor_mask=[0, 1, 2], num_classes=n_classes,
  272. anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
  273. num_anchors=9, stride=8)
  274. # R -4
  275. self.conv3 = Conv_Bn_Activation(128, 256, 3, 2, 'leaky')
  276. # R -1 -16
  277. self.conv4 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')
  278. self.conv5 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')
  279. self.conv6 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')
  280. self.conv7 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')
  281. self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')
  282. self.conv9 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')
  283. self.conv10 = Conv_Bn_Activation(512, output_ch, 1, 1, 'linear', bn=False, bias=True)
  284. self.yolo2 = YoloLayer(
  285. anchor_mask=[3, 4, 5], num_classes=n_classes,
  286. anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
  287. num_anchors=9, stride=16)
  288. # R -4
  289. self.conv11 = Conv_Bn_Activation(256, 512, 3, 2, 'leaky')
  290. # R -1 -37
  291. self.conv12 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')
  292. self.conv13 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')
  293. self.conv14 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')
  294. self.conv15 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')
  295. self.conv16 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')
  296. self.conv17 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')
  297. self.conv18 = Conv_Bn_Activation(1024, output_ch, 1, 1, 'linear', bn=False, bias=True)
  298. self.yolo3 = YoloLayer(
  299. anchor_mask=[6, 7, 8], num_classes=n_classes,
  300. anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
  301. num_anchors=9, stride=32)
  302. def forward(self, input1, input2, input3):
  303. x1 = self.conv1(input1)
  304. x2 = self.conv2(x1)
  305. x3 = self.conv3(input1)
  306. # R -1 -16
  307. x3 = torch.cat([x3, input2], dim=1)
  308. x4 = self.conv4(x3)
  309. x5 = self.conv5(x4)
  310. x6 = self.conv6(x5)
  311. x7 = self.conv7(x6)
  312. x8 = self.conv8(x7)
  313. x9 = self.conv9(x8)
  314. x10 = self.conv10(x9)
  315. # R -4
  316. x11 = self.conv11(x8)
  317. # R -1 -37
  318. x11 = torch.cat([x11, input3], dim=1)
  319. x12 = self.conv12(x11)
  320. x13 = self.conv13(x12)
  321. x14 = self.conv14(x13)
  322. x15 = self.conv15(x14)
  323. x16 = self.conv16(x15)
  324. x17 = self.conv17(x16)
  325. x18 = self.conv18(x17)
  326. if self.inference:
  327. y1 = self.yolo1(x2)
  328. y2 = self.yolo2(x10)
  329. y3 = self.yolo3(x18)
  330. return get_region_boxes([y1, y2, y3])
  331. else:
  332. return [x2, x10, x18]
  333. class Yolov4(nn.Module):
  334. def __init__(self, yolov4conv137weight=None, n_classes=80, inference=False):
  335. super().__init__()
  336. output_ch = (4 + 1 + n_classes) * 3
  337. # backbone
  338. self.down1 = DownSample1()
  339. self.down2 = DownSample2()
  340. self.down3 = DownSample3()
  341. self.down4 = DownSample4()
  342. self.down5 = DownSample5()
  343. # neck
  344. self.neek = Neck(inference)
  345. # yolov4conv137
  346. if yolov4conv137weight:
  347. _model = nn.Sequential(self.down1, self.down2, self.down3, self.down4, self.down5, self.neek)
  348. pretrained_dict = torch.load(yolov4conv137weight)
  349. model_dict = _model.state_dict()
  350. # 1. filter out unnecessary keys
  351. pretrained_dict = {k1: v for (k, v), k1 in zip(pretrained_dict.items(), model_dict)}
  352. # 2. overwrite entries in the existing state dict
  353. model_dict.update(pretrained_dict)
  354. _model.load_state_dict(model_dict)
  355. # head
  356. self.head = Yolov4Head(output_ch, n_classes, inference)
  357. def forward(self, input):
  358. d1 = self.down1(input)
  359. d2 = self.down2(d1)
  360. d3 = self.down3(d2)
  361. d4 = self.down4(d3)
  362. d5 = self.down5(d4)
  363. x20, x13, x6 = self.neek(d5, d4, d3)
  364. output = self.head(x20, x13, x6)
  365. return output
  366. if __name__ == "__main__":
  367. import sys
  368. import cv2
  369. namesfile = None
  370. if len(sys.argv) == 6:
  371. n_classes = int(sys.argv[1])
  372. weightfile = sys.argv[2]
  373. imgfile = sys.argv[3]
  374. height = int(sys.argv[4])
  375. width = int(sys.argv[5])
  376. elif len(sys.argv) == 7:
  377. n_classes = int(sys.argv[1])
  378. weightfile = sys.argv[2]
  379. imgfile = sys.argv[3]
  380. height = sys.argv[4]
  381. width = int(sys.argv[5])
  382. namesfile = int(sys.argv[6])
  383. else:
  384. print('Usage: ')
  385. print(' python models.py num_classes weightfile imgfile namefile')
  386. model = Yolov4(yolov4conv137weight=None, n_classes=n_classes, inference=True)
  387. pretrained_dict = torch.load(weightfile, map_location=torch.device('cuda'))
  388. model.load_state_dict(pretrained_dict)
  389. use_cuda = True
  390. if use_cuda:
  391. model.cuda()
  392. img = cv2.imread(imgfile)
  393. # Inference input size is 416*416 does not mean training size is the same
  394. # Training size could be 608*608 or even other sizes
  395. # Optional inference sizes:
  396. # Hight in {320, 416, 512, 608, ... 320 + 96 * n}
  397. # Width in {320, 416, 512, 608, ... 320 + 96 * m}
  398. sized = cv2.resize(img, (width, height))
  399. sized = cv2.cvtColor(sized, cv2.COLOR_BGR2RGB)
  400. from tool.utils import load_class_names, plot_boxes_cv2
  401. from tool.torch_utils import do_detect
  402. for i in range(2): # This 'for' loop is for speed check
  403. # Because the first iteration is usually longer
  404. boxes = do_detect(model, sized, 0.4, 0.6, use_cuda)
  405. if namesfile == None:
  406. if n_classes == 20:
  407. namesfile = 'data/voc.names'
  408. elif n_classes == 80:
  409. namesfile = 'data/coco.names'
  410. else:
  411. print("please give namefile")
  412. class_names = load_class_names(namesfile)
  413. plot_boxes_cv2(img, boxes[0], 'predictions.jpg', class_names)