# -*- coding: utf-8 -*-
# @Time : 2022/11/16 15:30
# @Author : Marvin.yuan
# @File : unext.py
# @Description : https://github.com/jeya-maria-jose/UNeXt-pytorch/blob/6ad0855114a35afbf81decf5dc912cd8de70476a/archs.py#L206
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import layers
from models.layers.drop import DropPath
from utils import registry


@registry.MODELS.register_module
class UNeXt(nn.Module):
    """3 conv encoder/decoder stages + 2 tokenized (shifted) MLP stages."""

    def __init__(self,
                 num_classes,
                 input_channels=3,
                 img_size=224,
                 embed_dims=(128, 160, 256),
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=(1, 1, 1),
                 apply_nonlin=False,
                 **kwargs):
        super().__init__()

        # Convolutional encoder (each stage: conv -> BN -> max-pool -> ReLU).
        self.encoder1 = nn.Conv2d(input_channels, 16, 3, stride=1, padding=1)
        self.encoder2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        self.encoder3 = nn.Conv2d(32, 128, 3, stride=1, padding=1)
        self.ebn1 = nn.BatchNorm2d(16)
        self.ebn2 = nn.BatchNorm2d(32)
        self.ebn3 = nn.BatchNorm2d(128)

        self.norm3 = norm_layer(embed_dims[1])
        self.norm4 = norm_layer(embed_dims[2])
        self.dnorm3 = norm_layer(160)
        self.dnorm4 = norm_layer(128)

        # Stochastic depth decay rule.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.block1 = nn.ModuleList([TokenizedMLPBlock(
            dim=embed_dims[1], mlp_ratio=1, drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer)])
        self.block2 = nn.ModuleList([TokenizedMLPBlock(
            dim=embed_dims[2], mlp_ratio=1, drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer)])
        self.dblock1 = nn.ModuleList([TokenizedMLPBlock(
            dim=embed_dims[1], mlp_ratio=1, drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer)])
        self.dblock2 = nn.ModuleList([TokenizedMLPBlock(
            dim=embed_dims[0], mlp_ratio=1, drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer)])

        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2,
                                              in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2,
                                              in_chans=embed_dims[1], embed_dim=embed_dims[2])

        self.decoder1 = layers.ConvBNAct(256, 160, 3)
        self.decoder2 = layers.ConvBNAct(160, 128, 3)
        self.decoder3 = layers.ConvBNAct(128, 32, 3)
        self.decoder4 = layers.ConvBNAct(32, 16, 3)
        self.decoder5 = nn.Conv2d(16, 16, 3, padding=1, bias=False)
        self.dbn1 = nn.BatchNorm2d(160)
        self.dbn2 = nn.BatchNorm2d(128)
        self.dbn3 = nn.BatchNorm2d(32)
        self.dbn4 = nn.BatchNorm2d(16)

        self.conv_seg = nn.Conv2d(16, num_classes, kernel_size=1)
        self.apply_nonlin = apply_nonlin

    def forward(self, x):
        B = x.shape[0]

        # ----- Encoder -----
        # Conv Stage
        # Stage 1
        out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)), 2, 2))
        t1 = out
        # Stage 2
        out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)), 2, 2))
        t2 = out
        # Stage 3
        out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)), 2, 2))
        t3 = out

        # Tokenized MLP Stage
        # Stage 4
        out, H, W = self.patch_embed3(out)
        for blk in self.block1:
            out = blk(out, H, W)
        out = self.norm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t4 = out

        # Bottleneck
        out, H, W = self.patch_embed4(out)
        for blk in self.block2:
            out = blk(out, H, W)
        out = self.norm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        # ----- Decoder (each stage: upsample 2x, add skip connection) -----
        # Stage 4
        out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t4)
        _, _, H, W = out.shape
        out = out.flatten(2).transpose(1, 2)
        for blk in self.dblock1:
            out = blk(out, H, W)
        # Stage 3
        out = self.dnorm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t3)
        _, _, H, W = out.shape
        out = out.flatten(2).transpose(1, 2)
        for blk in self.dblock2:
            out = blk(out, H, W)
        out = self.dnorm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        # Stage 2
        out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t2)
        # Stage 1
        out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t1)
        out = F.relu(F.interpolate(self.decoder5(out), scale_factor=(2, 2), mode='bilinear'))

        out = self.conv_seg(out)
        if self.apply_nonlin:
            out = F.softmax(out, dim=1)
        return out
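

# DWConv applies a depthwise 3x3 convolution to a token sequence by reshaping
# it back into a (B, C, H, W) feature map and flattening it again afterwards.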
class DWConv(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x
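

# ShiftedMLP is the tokenized-MLP mixer from UNeXt: before each of its two
# linear layers, the feature map is zero-padded, split into `shift_size`
# channel chunks, and each chunk is rolled by a different offset along one
# spatial axis, so the per-token MLP that follows mixes shifted neighbours.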
class ShiftedMLP(nn.Module):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.,
                 shift_size=5):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.shift_size = shift_size
        self.pad = shift_size // 2
        self.apply(layers.init_weights)

    def forward(self, x: torch.Tensor, H: int, W: int):
        B, N, C = x.shape

        # First axial shift: roll channel chunks along dim 2 (height), then
        # crop the zero padding back off.
        xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
        xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad), "constant", 0)
        xs = torch.chunk(xn, self.shift_size, 1)
        x_shift = [torch.roll(x_c, shift, 2) for x_c, shift in zip(xs, range(-self.pad, self.pad + 1))]
        x_cat = torch.cat(x_shift, 1)
        x_cat = torch.narrow(x_cat, 2, self.pad, H)
        x_s = torch.narrow(x_cat, 3, self.pad, W)
        x_s = x_s.reshape(B, C, H * W).contiguous()
        x_shift_r = x_s.transpose(1, 2)

        x = self.fc1(x_shift_r)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)

        # Second axial shift: roll channel chunks along dim 3 (width). fc1 may
        # have changed the channel count, so re-read it rather than reusing C.
        C = x.shape[-1]
        xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
        xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad), "constant", 0)
        xs = torch.chunk(xn, self.shift_size, 1)
        x_shift = [torch.roll(x_c, shift, 3) for x_c, shift in zip(xs, range(-self.pad, self.pad + 1))]
        x_cat = torch.cat(x_shift, 1)
        x_cat = torch.narrow(x_cat, 2, self.pad, H)
        x_s = torch.narrow(x_cat, 3, self.pad, W)
        x_s = x_s.reshape(B, C, H * W).contiguous()
        x_shift_c = x_s.transpose(1, 2)

        x = self.fc2(x_shift_c)
        x = self.drop(x)
        return x
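

# TokenizedMLPBlock wraps ShiftedMLP in a pre-norm residual branch with
# optional stochastic depth (DropPath).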
class TokenizedMLPBlock(nn.Module):
    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ShiftedMLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.apply(layers.init_weights)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.mlp(self.norm(x), H, W))
        return x
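

# Overlapping patch embedding: a strided convolution whose kernel is larger
# than its stride, so neighbouring patches share pixels.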
class OverlapPatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self,
                 img_size=224,
                 patch_size=7,
                 stride=4,
                 in_chans=3,
                 embed_dim=768):
        super().__init__()
        img_size = (img_size, img_size) if isinstance(img_size, int) else img_size[:2]
        patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size[:2]
        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(layers.init_weights)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W


if __name__ == "__main__":
    model = UNeXt(3, img_size=256)
    model.eval()
    dummy = torch.rand(2, 3, 256, 256)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)
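    # Sanity check (assuming the default stride-2 pools/patch embeddings and
    # the five 2x decoder upsamplings): the logits should come back at the
    # input resolution.
    assert logits.shape == (2, 3, 256, 256)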