# D:/workplace/python
# -*- coding: utf-8 -*-
# @File :bisenetv2.py
# @Author:Guido LuXiaohao
# @Date :2022/3/11
# @Software:PyCharm
"""
Reference from: https://github.com/CoinCheung/BiSeNet
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import registry


class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1,
                 dilation=1, groups=1, bias=False):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_chan, out_chan, kernel_size=ks, stride=stride,
            padding=padding, dilation=dilation,
            groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU()

    def forward(self, x):
        feat = self.conv(x)
        feat = self.bn(feat)
        feat = self.relu(feat)
        return feat


class DetailBranch(nn.Module):
    """Detail Branch: wide, shallow path that preserves spatial detail (output stride 8)."""

    def __init__(self):
        super(DetailBranch, self).__init__()
        self.S1 = nn.Sequential(
            ConvBNReLU(3, 32, 3, stride=2),
            ConvBNReLU(32, 32, 3, stride=1),
        )
        self.S2 = nn.Sequential(
            ConvBNReLU(32, 64, 3, stride=2),
            ConvBNReLU(64, 64, 3, stride=1),
            ConvBNReLU(64, 64, 3, stride=1),
        )
        self.S3 = nn.Sequential(
            ConvBNReLU(64, 128, 3, stride=2),
            ConvBNReLU(128, 128, 3, stride=1),
            ConvBNReLU(128, 128, 3, stride=1),
        )

    def forward(self, x):
        feat = self.S1(x)
        feat = self.S2(feat)
        feat = self.S3(feat)
        return feat


class StemBlock(nn.Module):
    """Stem of the Semantic Branch: two parallel downsampling paths fused to 16 channels at 1/4 resolution."""

    def __init__(self):
        super(StemBlock, self).__init__()
        self.conv = ConvBNReLU(3, 16, 3, stride=2)
        self.left = nn.Sequential(
            ConvBNReLU(16, 8, 1, stride=1, padding=0),
            ConvBNReLU(8, 16, 3, stride=2),
        )
        self.right = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=False)
        self.fuse = ConvBNReLU(32, 16, 3, stride=1)

    def forward(self, x):
        feat = self.conv(x)
        feat_left = self.left(feat)
        feat_right = self.right(feat)
        feat = torch.cat([feat_left, feat_right], dim=1)
        feat = self.fuse(feat)
        return feat


class CEBlock(nn.Module):
    """Context Embedding block: global-average-pooled context is re-injected into the feature map."""

    def __init__(self):
        super(CEBlock, self).__init__()
        self.bn = nn.BatchNorm2d(128)
        self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
        self.conv_last = ConvBNReLU(128, 128, 3, stride=1)

    def forward(self, x):
        feat = torch.mean(x, dim=(2, 3), keepdim=True)  # global average pooling
        feat = self.bn(feat)
        feat = self.conv_gap(feat)
        feat = feat + x  # broadcast the 1x1 context over the spatial map
        feat = self.conv_last(feat)
        return feat


class GELayerS1(nn.Module):
    """Gather-and-Expansion layer with stride 1 (identity shortcut)."""

    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super(GELayerS1, self).__init__()
        mid_chan = in_chan * exp_ratio
        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        self.dwconv = nn.Sequential(
            nn.Conv2d(
                in_chan, mid_chan, kernel_size=3, stride=1,
                padding=1, groups=in_chan, bias=False),  # depthwise expansion
            nn.BatchNorm2d(mid_chan),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, out_chan, kernel_size=1, stride=1,
                padding=0, bias=False),  # pointwise projection
            nn.BatchNorm2d(out_chan)
        )
        self.conv2[1].last_bn = True
        self.relu = nn.ReLU()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.dwconv(feat)
        feat = self.conv2(feat)
        feat = feat + x
        feat = self.relu(feat)
        return feat


class GELayerS2(nn.Module):
    """Gather-and-Expansion layer with stride 2 (downsampling, depthwise-separable shortcut)."""

    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super(GELayerS2, self).__init__()
        mid_chan = in_chan * exp_ratio
        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        self.dwconv1 = nn.Sequential(
            nn.Conv2d(
                in_chan, mid_chan, kernel_size=3, stride=2,
                padding=1, groups=in_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
        )
        self.dwconv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, mid_chan, kernel_size=3, stride=1,
                padding=1, groups=mid_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, out_chan, kernel_size=1, stride=1,
                padding=0, bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.conv2[1].last_bn = True
        self.shortcut = nn.Sequential(
            nn.Conv2d(
                in_chan, in_chan, kernel_size=3, stride=2,
                padding=1, groups=in_chan, bias=False),
            nn.BatchNorm2d(in_chan),
            nn.Conv2d(
                in_chan, out_chan, kernel_size=1, stride=1,
                padding=0, bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.dwconv1(feat)
        feat = self.dwconv2(feat)
        feat = self.conv2(feat)
        shortcut = self.shortcut(x)
        feat = feat + shortcut
        feat = self.relu(feat)
        return feat


class SegmentBranch(nn.Module):
    """Semantic Branch: deep, thin path that returns features at strides 4, 8, 16, 32, 32."""

    def __init__(self):
        super(SegmentBranch, self).__init__()
        self.S1S2 = StemBlock()
        self.S3 = nn.Sequential(
            GELayerS2(16, 32),
            GELayerS1(32, 32),
        )
        self.S4 = nn.Sequential(
            GELayerS2(32, 64),
            GELayerS1(64, 64),
        )
        self.S5_4 = nn.Sequential(
            GELayerS2(64, 128),
            GELayerS1(128, 128),
            GELayerS1(128, 128),
            GELayerS1(128, 128),
        )
        self.S5_5 = CEBlock()

    def forward(self, x):
        feat2 = self.S1S2(x)
        feat3 = self.S3(feat2)
        feat4 = self.S4(feat3)
        feat5_4 = self.S5_4(feat4)
        feat5_5 = self.S5_5(feat5_4)
        return feat2, feat3, feat4, feat5_4, feat5_5


class BGALayer(nn.Module):
    """Bilateral Guided Aggregation: fuses the detail feature (stride 8) with the semantic feature (stride 32)."""

    def __init__(self):
        super(BGALayer, self).__init__()
        self.left1 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, groups=128, bias=False),
            nn.BatchNorm2d(128),
            nn.Conv2d(
                128, 128, kernel_size=1, stride=1,
                padding=0, bias=False),
        )
        self.left2 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=2,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
        )
        self.right1 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
        )
        self.right2 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, groups=128, bias=False),
            nn.BatchNorm2d(128),
            nn.Conv2d(
                128, 128, kernel_size=1, stride=1,
                padding=0, bias=False),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

    def forward(self, x_d, x_s):
        dsize = x_d.size()[2:]
        left1 = self.left1(x_d)
        left2 = self.left2(x_d)
        right1 = self.right1(x_s)
        right2 = self.right2(x_s)
        right1 = F.interpolate(
            right1, size=dsize, mode='bilinear', align_corners=True)
        left = left1 * torch.sigmoid(right1)
        right = left2 * torch.sigmoid(right2)
        right = F.interpolate(
            right, size=dsize, mode='bilinear', align_corners=True)
        out = self.conv(left + right)
        return out


class BiSeNetHead(nn.Module):
    """Segmentation head: 3x3 ConvBNReLU, optional dropout, 1x1 classifier, optional bilinear upsampling."""

    def __init__(self, in_chan, mid_chan, num_classes, dropout_ratio):
        super(BiSeNetHead, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
        self.conv_seg = nn.Conv2d(
            mid_chan, num_classes, kernel_size=1, stride=1,
            padding=0, bias=False)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout(dropout_ratio)
        else:
            self.dropout = None

    def forward(self, x, size=None):
        feat = self.conv(x)
        if self.dropout is not None:
            feat = self.dropout(feat)
        feat = self.conv_seg(feat)
        if size is not None:
            feat = F.interpolate(feat, size=size, mode='bilinear', align_corners=True)
        return feat


class _DenseASPPConv(nn.Sequential):
    def __init__(self, in_channels, inter_channels, out_channels, atrous_rate,
                 drop_rate=0.1, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPConv, self).__init__()
        self.add_module('conv1', nn.Conv2d(in_channels, inter_channels, 1))
        self.add_module('bn1', norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)))
        self.add_module('relu1', nn.ReLU())
        self.add_module('conv2', nn.Conv2d(inter_channels, out_channels, 3, dilation=atrous_rate, padding=atrous_rate))
        self.add_module('bn2', norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)))
        self.add_module('relu2', nn.ReLU())
        self.drop_rate = drop_rate

    def forward(self, x):
        features = super(_DenseASPPConv, self).forward(x)
        if self.drop_rate > 0:
            features = F.dropout(features, p=self.drop_rate, training=self.training)
        return features


class _DenseASPPBlock(nn.Module):
    """DenseASPP block: atrous convolutions with increasing rates, densely concatenated along the channel dim."""

    def __init__(self, in_channels, inter_channels1, inter_channels2,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPBlock, self).__init__()
        self.aspp_3 = _DenseASPPConv(in_channels, inter_channels1, inter_channels2, 3, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_5 = _DenseASPPConv(in_channels + inter_channels2 * 1, inter_channels1, inter_channels2, 5, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_7 = _DenseASPPConv(in_channels + inter_channels2 * 2, inter_channels1, inter_channels2, 7, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_9 = _DenseASPPConv(in_channels + inter_channels2 * 3, inter_channels1, inter_channels2, 9, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_12 = _DenseASPPConv(in_channels + inter_channels2 * 4, inter_channels1, inter_channels2, 12, 0.1,
                                      norm_layer, norm_kwargs)

    def forward(self, x):
        # Each branch sees the input plus all previous branch outputs; the final
        # tensor has in_channels + 5 * inter_channels2 channels.
        aspp3 = self.aspp_3(x)
        x = torch.cat([aspp3, x], dim=1)
        aspp5 = self.aspp_5(x)
        x = torch.cat([aspp5, x], dim=1)
        aspp7 = self.aspp_7(x)
        x = torch.cat([aspp7, x], dim=1)
        aspp9 = self.aspp_9(x)
        x = torch.cat([aspp9, x], dim=1)
        aspp12 = self.aspp_12(x)
        x = torch.cat([aspp12, x], dim=1)
        return x


class _DenseASPPHead(nn.Module):
    def __init__(self, in_channels, num_classes, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DenseASPPHead, self).__init__()
        self.dense_aspp_block = _DenseASPPBlock(in_channels, 256, 64, norm_layer, norm_kwargs)
        self.block = nn.Sequential(
            nn.Dropout(0.1),
            nn.Conv2d(in_channels + 5 * 64, num_classes, 1)
        )

    def forward(self, x):
        x = self.dense_aspp_block(x)
        return self.block(x)


@registry.MODELS.register_module
class BiSeNetV2(nn.Module):
    """BiSeNetV2: returns [aux2, aux3, aux4, aux5_4, main] logits in training mode, only the main output in eval mode."""

    def __init__(self, num_classes, dropout_ratio=0.1, apply_nonlin=False):
        super(BiSeNetV2, self).__init__()
        self.detail = DetailBranch()
        self.segment = SegmentBranch()
        self.bga = BGALayer()
        self.heads = nn.ModuleList([
            BiSeNetHead(16, 128, num_classes, dropout_ratio),   # aux2 segmentation head
            BiSeNetHead(32, 128, num_classes, dropout_ratio),   # aux3 segmentation head
            BiSeNetHead(64, 128, num_classes, dropout_ratio),   # aux4 segmentation head
            BiSeNetHead(128, 128, num_classes, dropout_ratio),  # aux5_4 segmentation head
            BiSeNetHead(128, 128, num_classes, dropout_ratio),  # main segmentation head
        ])
        self.apply_nonlin = apply_nonlin

    def forward(self, x):
        img_sz = x.size()[2:]
        feat_d = self.detail(x)
        feat_out = self.segment(x)
        feat_out = list(feat_out)
        feat_out[-1] = self.bga(feat_d, feat_out[-1])
        outputs = []
        if self.training:
            for i, head in enumerate(self.heads):
                out = head(feat_out[i], img_sz)
                outputs.append(out)
        else:
            out = self.heads[-1](feat_out[-1], img_sz)
            if self.apply_nonlin:
                out = F.softmax(out, dim=1)
            outputs.append(out)
        return outputs
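

# Usage sketch (illustrative; it exercises the eval-mode branch of forward() above
# and assumes a 3-class model with a 256x256 RGB input):
#
#     model = BiSeNetV2(num_classes=3, apply_nonlin=True).eval()
#     with torch.no_grad():
#         probs, = model(torch.rand(1, 3, 256, 256))
#     # probs has shape [1, 3, 256, 256] with softmax applied over dim=1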


@registry.MODELS.register_module
class BiseNetV2DenseASPP(nn.Module):
    """BiSeNetV2 variant that refines the fused BGA feature with a DenseASPP head before the segmentation heads."""

    def __init__(self, num_classes, dropout_ratio=0.1, apply_nonlin=False):
        super(BiseNetV2DenseASPP, self).__init__()
        self.detail = DetailBranch()
        self.segment = SegmentBranch()
        self.bga = BGALayer()
        self.heads = nn.ModuleList([
            BiSeNetHead(16, 128, num_classes, dropout_ratio),   # aux2 segmentation head
            BiSeNetHead(32, 128, num_classes, dropout_ratio),   # aux3 segmentation head
            BiSeNetHead(64, 128, num_classes, dropout_ratio),   # aux4 segmentation head
            BiSeNetHead(128, 128, num_classes, dropout_ratio),  # aux5_4 segmentation head
            BiSeNetHead(128, 128, num_classes, dropout_ratio),  # main segmentation head
        ])
        self.denseaspp_head = _DenseASPPHead(128, 128)  # keeps 128 output channels for the main head
        self.apply_nonlin = apply_nonlin

    def forward(self, x):
        img_sz = x.size()[2:]
        feat_d = self.detail(x)
        feat_out = self.segment(x)
        feat_out = list(feat_out)
        feat_out[-1] = self.denseaspp_head(self.bga(feat_d, feat_out[-1]))
        outputs = []
        if self.training:
            for i, head in enumerate(self.heads):
                out = head(feat_out[i], img_sz)
                outputs.append(out)
        else:
            out = self.heads[-1](feat_out[-1], img_sz)
            if self.apply_nonlin:
                out = F.softmax(out, dim=1)
            outputs.append(out)
        return outputs


if __name__ == '__main__':
    dummy_data = torch.rand(2, 3, 256, 256)
    model = BiSeNetV2(num_classes=3)
    output = model(dummy_data)
    print(output)
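
    # Quick sanity check (illustrative): in training mode (the default) forward()
    # returns five per-head logit maps, each upsampled to the input resolution.
    assert len(output) == 5
    for out in output:
        assert out.shape == (2, 3, 256, 256)

    # In eval mode only the main head is run and a single output is returned.
    model.eval()
    with torch.no_grad():
        eval_output = model(dummy_data)
    assert len(eval_output) == 1
    print(eval_output[0].shape)  # expected: torch.Size([2, 3, 256, 256])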