pp_liteseg.py

# -*- coding: utf-8 -*-
# @Time : 2022/10/20 16:25
# @Author : Marvin.yuan
# @File : pp_liteseg.py
# @Description :
import torch
import torch.nn as nn
import torch.nn.functional as F

from .components.decoders import BaseSegDecoder
from .layers import (ConvBN, ConvBNAct, CustomAdaptiveAvgPool2d,
                     custom_adaptive_avg_pool2d, custom_adaptive_max_pool2d)
from utils import registry


@registry.MODELS.register_module
class PPLiteSegNeck(nn.Module):
    """The neck of PP-LiteSeg: a PP Context Module on the deepest feature,
    followed by a top-down cascade of Unified Attention Fusion Modules."""

    def __init__(self, arm_in_chs, in_index, arm_out_chs, arm_type,
                 cm_bin_sizes, cm_out_ch, resize_mode='bilinear'):
        super().__init__()
        self.in_index = in_index if isinstance(in_index, list) else [in_index]
        # PP Context Module
        self.cm = PPContextModule(arm_in_chs[-1], cm_out_ch, cm_out_ch,
                                  cm_bin_sizes)
        # Arm Module: resolve the UAFM variant by its class name
        arm_class = eval(arm_type)
        self.arm_list = nn.ModuleList()  # [..., arm8, arm16, arm32]
        for i in range(len(arm_in_chs)):
            low_chs = arm_in_chs[i]
            high_ch = cm_out_ch if i == len(
                arm_in_chs) - 1 else arm_out_chs[i + 1]
            out_ch = arm_out_chs[i]
            arm = arm_class(
                low_chs, high_ch, out_ch, ksize=3, resize_mode=resize_mode)
            self.arm_list.append(arm)

    def forward(self, inputs):
        in_feat_list = [inputs[idx] for idx in self.in_index]
        high_feat = self.cm(in_feat_list[-1])
        out_feat_list = []
        for i in reversed(range(len(in_feat_list))):
            low_feat = in_feat_list[i]
            arm = self.arm_list[i]
            high_feat = arm(low_feat, high_feat)
            out_feat_list.insert(0, high_feat)
        return out_feat_list
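
# A minimal shape walkthrough of the neck's top-down fusion (the values are
# illustrative, assuming three backbone stages at strides 8/16/32 with
# arm_out_chs=[32, 64, 128] and cm_out_ch=128; none of this is fixed by the
# code above):
#
#   x32 --cm--> high_feat                    (N, 128, H/32, W/32)
#   arm32(low=x32, high=high_feat)        -> (N, 128, H/32, W/32)
#   arm16(low=x16, high=previous output)  -> (N,  64, H/16, W/16)
#   arm8 (low=x8,  high=previous output)  -> (N,  32, H/8,  W/8)
#
# out_feat_list ends up ordered from low stride to high stride because each
# result is insert(0)-ed while the loop runs in reverse.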


@registry.MODELS.register_module
class PPLiteSegDecoder(BaseSegDecoder):
    """The decoder of PP-LiteSeg."""

    def __init__(self, arm_out_chs, arm_type, cm_bin_sizes, cm_out_ch,
                 resize_mode='bilinear', **kwargs):
        super().__init__(**kwargs)
        # PP Context Module
        self.cm = PPContextModule(self.in_channels[-1], cm_out_ch, cm_out_ch,
                                  cm_bin_sizes)
        # Arm Module: resolve the UAFM variant by its class name
        arm_class = eval(arm_type)
        self.arm_list = nn.ModuleList()  # [..., arm8, arm16, arm32]
        for i in range(len(self.in_channels)):
            low_chs = self.in_channels[i]
            high_ch = cm_out_ch if i == len(
                self.in_channels) - 1 else arm_out_chs[i + 1]
            out_ch = arm_out_chs[i]
            arm = arm_class(
                low_chs, high_ch, out_ch, ksize=3, resize_mode=resize_mode)
            self.arm_list.append(arm)

    def _forward_features(self, inputs):
        """
        Args:
            inputs (list[Tensor]): Such as [x2, x4, x8, x16, x32];
                x2, x4 and x8 are optional.
        Returns:
            out_feat_list (list[Tensor]): Such as [x2, x4, x8, x16, x32];
                x2, x4 and x8 are optional. in_feat_list and out_feat_list
                have the same length.
        """
        in_feat_list = self._transform_inputs(inputs)
        high_feat = self.cm(in_feat_list[-1])
        out_feat_list = []
        for i in reversed(range(len(in_feat_list))):
            low_feat = in_feat_list[i]
            arm = self.arm_list[i]
            high_feat = arm(low_feat, high_feat)
            out_feat_list.insert(0, high_feat)
        return out_feat_list

    def forward(self, inputs, data_samples=None):
        x = self._forward_features(inputs)
        x = self.decode_head(x)
        return x
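
# A hedged instantiation sketch; the argument values are illustrative, and the
# **kwargs consumed by BaseSegDecoder (which must provide the in_channels,
# _transform_inputs and decode_head used above) live outside this file:
#
#   decoder = PPLiteSegDecoder(
#       arm_out_chs=[32, 64, 128],
#       arm_type='UAFM_SpAtten',    # resolved to a class in this module by eval()
#       cm_bin_sizes=(1, 2, 4),
#       cm_out_ch=128,
#       resize_mode='bilinear',
#       **base_decoder_kwargs)      # hypothetical BaseSegDecoder arguments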


class UAFM(nn.Module):
    """
    The base of the Unified Attention Fusion Module.
    Args:
        x_ch (int): The channel count of the x tensor, the low-level feature.
        y_ch (int): The channel count of the y tensor, the high-level feature.
        out_ch (int): The channel count of the output tensor.
        ksize (int, optional): The kernel size of the conv for the x tensor.
            Default: 3.
        resize_mode (str, optional): The interpolation mode used to upsample
            the y tensor. Default: 'bilinear'.
    """

    def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'):
        super().__init__()
        self.conv_x = ConvBNAct(x_ch, y_ch, kernel_size=ksize)
        self.conv_out = ConvBNAct(y_ch, out_ch, kernel_size=3)
        self.resize_mode = resize_mode

    def check(self, x, y):
        assert x.ndim == 4 and y.ndim == 4
        x_h, x_w = x.shape[2:]
        y_h, y_w = y.shape[2:]
        assert x_h >= y_h and x_w >= y_w

    def prepare(self, x, y):
        x = self.prepare_x(x, y)
        y = self.prepare_y(x, y)
        return x, y

    def prepare_x(self, x, y):
        x = self.conv_x(x)
        return x

    def prepare_y(self, x, y):
        # Upsample y to x's spatial size.
        y_up = F.interpolate(y, x.shape[-2:], mode=self.resize_mode)
        return y_up

    def fuse(self, x, y):
        out = x + y
        out = self.conv_out(out)
        return out

    def forward(self, x, y):
        """
        Args:
            x (Tensor): The low-level feature.
            y (Tensor): The high-level feature.
        """
        self.check(x, y)
        x, y = self.prepare(x, y)
        out = self.fuse(x, y)
        return out
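
# A minimal sketch of the base UAFM contract, assuming x_ch=64, y_ch=128,
# out_ch=64 and a stride-8/stride-16 feature pair (the shapes are
# illustrative):
#
#   x: (N, 64, 64, 64)    low-level feature, larger spatial size
#   y: (N, 128, 32, 32)   high-level feature, smaller spatial size
#   prepare_x: conv_x projects x to y's width  -> (N, 128, 64, 64)
#   prepare_y: F.interpolate upsamples y to x  -> (N, 128, 64, 64)
#   fuse:      x + y, then conv_out            -> (N, 64, 64, 64)
#
# Subclasses only override fuse() to replace the plain sum with attention.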


class UAFM_ChAtten(UAFM):
    """
    The UAFM with channel attention, which uses mean and max values.
    Args:
        x_ch (int): The channel count of the x tensor, the low-level feature.
        y_ch (int): The channel count of the y tensor, the high-level feature.
        out_ch (int): The channel count of the output tensor.
        ksize (int, optional): The kernel size of the conv for the x tensor.
            Default: 3.
        resize_mode (str, optional): The interpolation mode used to upsample
            the y tensor. Default: 'bilinear'.
    """

    def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'):
        super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode)
        self.conv_xy_atten = nn.Sequential(
            ConvBNAct(4 * y_ch, y_ch // 2, kernel_size=1),
            ConvBN(y_ch // 2, y_ch, kernel_size=1))

    @staticmethod
    def avg_max_reduce_hw(x):
        # Reduce hw by avg and max.
        # Return cat([avg_pool_0, avg_pool_1, ..., max_pool_0, max_pool_1, ...])
        if not isinstance(x, (list, tuple)):
            x = [x]
        res_avg = []
        res_max = []
        for xi in x:
            avg_pool = custom_adaptive_avg_pool2d(xi, 1)
            max_pool = custom_adaptive_max_pool2d(xi, 1)
            res_avg.append(avg_pool)
            res_max.append(max_pool)
        res = res_avg + res_max
        return torch.cat(res, dim=1)

    def fuse(self, x, y):
        """
        Args:
            x (Tensor): The low-level feature.
            y (Tensor): The high-level feature.
        """
        atten = self.avg_max_reduce_hw([x, y])
        atten = torch.sigmoid(self.conv_xy_atten(atten))
        # Learned per-channel convex combination of x and y.
        out = x * atten + y * (1 - atten)
        out = self.conv_out(out)
        return out
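
# Why conv_xy_atten takes 4 * y_ch inputs: after prepare(), x and y both have
# y_ch channels, and avg_max_reduce_hw([x, y]) concatenates avg(x), avg(y),
# max(x), max(y), each (N, y_ch, 1, 1), into a (N, 4*y_ch, 1, 1) descriptor.
# conv_xy_atten squeezes that back to y_ch channels, so the sigmoid yields one
# weight per channel, broadcast over all spatial positions.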


class UAFM_ChAtten_S(UAFM):
    """
    The UAFM with channel attention, which uses mean values only.
    Args:
        x_ch (int): The channel count of the x tensor, the low-level feature.
        y_ch (int): The channel count of the y tensor, the high-level feature.
        out_ch (int): The channel count of the output tensor.
        ksize (int, optional): The kernel size of the conv for the x tensor.
            Default: 3.
        resize_mode (str, optional): The interpolation mode used to upsample
            the y tensor. Default: 'bilinear'.
    """

    def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'):
        super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode)
        self.conv_xy_atten = nn.Sequential(
            ConvBNAct(2 * y_ch, y_ch // 2, kernel_size=1),
            ConvBN(y_ch // 2, y_ch, kernel_size=1))

    @staticmethod
    def avg_reduce_hw(x):
        # Reduce hw by avg.
        # Return cat([avg_pool_0, avg_pool_1, ...])
        if not isinstance(x, (list, tuple)):
            return custom_adaptive_avg_pool2d(x, 1)
        elif len(x) == 1:
            return custom_adaptive_avg_pool2d(x[0], 1)
        else:
            res = []
            for xi in x:
                res.append(custom_adaptive_avg_pool2d(xi, 1))
            return torch.cat(res, dim=1)

    def fuse(self, x, y):
        """
        Args:
            x (Tensor): The low-level feature.
            y (Tensor): The high-level feature.
        """
        atten = self.avg_reduce_hw([x, y])
        atten = torch.sigmoid(self.conv_xy_atten(atten))
        out = x * atten + y * (1 - atten)
        out = self.conv_out(out)
        return out


class UAFM_SpAtten(UAFM):
    """
    The UAFM with spatial attention, which uses mean and max values.
    Args:
        x_ch (int): The channel count of the x tensor, the low-level feature.
        y_ch (int): The channel count of the y tensor, the high-level feature.
        out_ch (int): The channel count of the output tensor.
        ksize (int, optional): The kernel size of the conv for the x tensor.
            Default: 3.
        resize_mode (str, optional): The interpolation mode used to upsample
            the y tensor. Default: 'bilinear'.
    """

    def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'):
        super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode)
        self.conv_xy_atten = nn.Sequential(
            ConvBNAct(4, 2, kernel_size=3),
            ConvBN(2, 1, kernel_size=3))

    @staticmethod
    def avg_max_reduce_channel(x):
        # Reduce channel by avg and max.
        # Return cat([avg_ch_0, max_ch_0, avg_ch_1, max_ch_1, ...])
        if not isinstance(x, (list, tuple)):
            x = [x]
        res = []
        for xi in x:
            mean_value = torch.mean(xi, dim=1, keepdim=True)
            max_value = torch.max(xi, dim=1, keepdim=True)[0]
            res.extend([mean_value, max_value])
        return torch.cat(res, dim=1)

    def fuse(self, x, y):
        """
        Args:
            x (Tensor): The low-level feature.
            y (Tensor): The high-level feature.
        """
        atten = self.avg_max_reduce_channel([x, y])
        atten = torch.sigmoid(self.conv_xy_atten(atten))
        # Learned per-pixel convex combination of x and y.
        out = x * atten + y * (1 - atten)
        out = self.conv_out(out)
        return out
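
# Why conv_xy_atten takes 4 input channels: avg_max_reduce_channel([x, y])
# stacks mean(x), max(x), mean(y), max(y), each (N, 1, H, W), into a
# (N, 4, H, W) map. conv_xy_atten reduces that to a single channel, so the
# sigmoid yields one weight per spatial location instead of one per channel.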


class UAFM_SpAtten_S(UAFM):
    """
    The UAFM with spatial attention, which uses mean values only.
    Args:
        x_ch (int): The channel count of the x tensor, the low-level feature.
        y_ch (int): The channel count of the y tensor, the high-level feature.
        out_ch (int): The channel count of the output tensor.
        ksize (int, optional): The kernel size of the conv for the x tensor.
            Default: 3.
        resize_mode (str, optional): The interpolation mode used to upsample
            the y tensor. Default: 'bilinear'.
    """

    def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'):
        super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode)
        self.conv_xy_atten = nn.Sequential(
            ConvBNAct(2, 2, kernel_size=3),
            ConvBN(2, 1, kernel_size=3))

    @staticmethod
    def avg_reduce_channel(x):
        # Reduce channel by avg.
        # Return cat([avg_ch_0, avg_ch_1, ...])
        if not isinstance(x, (list, tuple)):
            return torch.mean(x, dim=1, keepdim=True)
        elif len(x) == 1:
            return torch.mean(x[0], dim=1, keepdim=True)
        else:
            res = []
            for xi in x:
                res.append(torch.mean(xi, dim=1, keepdim=True))
            return torch.cat(res, dim=1)

    def fuse(self, x, y):
        """
        Args:
            x (Tensor): The low-level feature.
            y (Tensor): The high-level feature.
        """
        atten = self.avg_reduce_channel([x, y])
        atten = torch.sigmoid(self.conv_xy_atten(atten))
        out = x * atten + y * (1 - atten)
        out = self.conv_out(out)
        return out


class PPContextModule(nn.Module):
    """Simple pyramid pooling context module: the input is pooled to each
    bin size, projected by a 1x1 conv, upsampled back to the input size,
    and the branch outputs are summed before a final 3x3 conv."""

    def __init__(self,
                 in_channels: int,
                 inter_channels: int,
                 out_channels: int,
                 bin_sizes: tuple,
                 align_corners: bool = False):
        super().__init__()
        self.stages = nn.ModuleList([
            self._make_stage(in_channels, inter_channels, size)
            for size in bin_sizes
        ])
        self.conv_out = ConvBNAct(inter_channels, out_channels, kernel_size=3)
        self.align_corners = align_corners

    def _make_stage(self, in_channels, out_channels, size):
        prior = CustomAdaptiveAvgPool2d(size)
        conv = ConvBNAct(in_channels, out_channels, kernel_size=1)
        return nn.Sequential(prior, conv)

    def forward(self, input):
        out = None
        input_shape = input.shape[-2:]
        for stage in self.stages:
            x = stage(input)
            x = F.interpolate(
                x,
                input_shape,
                mode='bilinear',
                align_corners=self.align_corners)
            if out is None:
                out = x
            else:
                out += x
        out = self.conv_out(out)
        return out
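
# A minimal shape sketch of the context module, assuming bin_sizes=(1, 2, 4)
# (the bin sizes are illustrative):
#
#   input                                        (N, in_channels, h, w)
#   per stage s: adaptive avg pool to (s, s), 1x1 conv to inter_channels,
#                bilinear upsample back to (h, w)
#   branch outputs are summed (not concatenated), then conv_out projects to
#                                                (N, out_channels, h, w)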