# -*- coding: utf-8 -*-
# @Time : 2022/12/15 13:24
# @Author : Marvin.yuan
# @File : unet.py
# @Description : U-Net model: encoder/decoder with skip connections and a 1x1 conv producing per-pixel class logits.
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from models import layers
  10. from utils import registry
  11. @registry.MODELS.register_module
  12. class UNet(nn.Module):
  13. def __init__(self, num_classes, base_width=64, align_corners=False, use_deconv=False):
  14. super().__init__()
  15. self.encode = Encoder(base_width)
  16. self.decode = Decoder(base_width, align_corners, use_deconv)
  17. self.cls = nn.Conv2d(base_width, num_classes, kernel_size=1)
  18. self.apply(layers.init_weights)
  19. def forward(self, x):
  20. x, short_cuts = self.encode(x)
  21. x = self.decode(x, short_cuts)
  22. logit = self.cls(x)
  23. return logit
  24. class Encoder(nn.Module):
  25. def __init__(self, base_width):
  26. super().__init__()
  27. self.double_conv = nn.Sequential(
  28. layers.ConvBNAct(3, base_width, 3), layers.ConvBNAct(base_width, base_width, 3))
  29. down_channels = [[base_width * 2 ** i, base_width * 2 ** (i + 1)] for i in range(4)]
  30. self.down_sample_list = nn.ModuleList([
  31. self.down_sampling(channel[0], channel[1])
  32. for channel in down_channels
  33. ])
  34. @staticmethod
  35. def down_sampling(in_channels, out_channels):
  36. return nn.Sequential(
  37. nn.MaxPool2d(kernel_size=2, stride=2),
  38. layers.ConvBNAct(in_channels, out_channels, 3),
  39. layers.ConvBNAct(out_channels, out_channels, 3))
  40. def forward(self, x):
  41. short_cuts = []
  42. x = self.double_conv(x)
  43. for down_sample in self.down_sample_list:
  44. short_cuts.append(x)
  45. x = down_sample(x)
  46. return x, short_cuts
  47. class Decoder(nn.Module):
  48. def __init__(self, base_width, align_corners, use_deconv=False):
  49. super().__init__()
  50. up_channels = [[base_width * 2 ** (i + 1), base_width * 2 ** i] for i in reversed(range(4))]
  51. self.up_sample_list = nn.ModuleList([
  52. Upsampling(channel[0], channel[1], align_corners, use_deconv)
  53. for channel in up_channels
  54. ])
  55. def forward(self, x, short_cuts):
  56. for i in range(len(short_cuts)):
  57. x = self.up_sample_list[i](x, short_cuts[-(i + 1)])
  58. return x
  59. class Upsampling(nn.Module):
  60. def __init__(self,
  61. in_channels,
  62. out_channels,
  63. align_corners,
  64. use_deconv=False):
  65. super().__init__()
  66. self.align_corners = align_corners
  67. self.use_deconv = use_deconv
  68. if use_deconv:
  69. self.deconv = nn.ConvTranspose2d(
  70. in_channels, out_channels, kernel_size=2, stride=2, padding=0)
  71. in_channels = out_channels * 2
  72. else:
  73. in_channels = in_channels + out_channels
  74. self.double_conv = nn.Sequential(
  75. layers.ConvBNAct(in_channels, out_channels, 3),
  76. layers.ConvBNAct(out_channels, out_channels, 3))
  77. def forward(self, x, short_cut):
  78. if self.use_deconv:
  79. x = self.deconv(x)
  80. else:
  81. x = F.interpolate(
  82. x,
  83. short_cut.shape[2:],
  84. mode='bilinear',
  85. align_corners=self.align_corners)
  86. x = torch.cat([x, short_cut], dim=1)
  87. x = self.double_conv(x)
  88. return x
  89. if __name__ == "__main__":
  90. dummy_input = torch.rand(1, 3, 256, 256)
  91. net = UNet(13, base_width=64, align_corners=False, use_deconv=True)
  92. output = net(dummy_input)
  93. print(output.shape)