# -*- coding: utf-8 -*-
# @Time : 2022/12/15 13:24
# @Author : Marvin.yuan
# @File : unet.py
# @Description : U-Net for semantic segmentation (encoder-decoder with skip connections).

import torch
import torch.nn as nn
import torch.nn.functional as F

from models import layers
from utils import registry

@registry.MODELS.register_module
class UNet(nn.Module):
    """Vanilla U-Net: encoder, decoder with skip connections, and a 1x1 classifier head."""

    def __init__(self, num_classes, base_width=64, align_corners=False, use_deconv=False):
        super().__init__()
        self.encode = Encoder(base_width)
        self.decode = Decoder(base_width, align_corners, use_deconv)
        self.cls = nn.Conv2d(base_width, num_classes, kernel_size=1)
        self.apply(layers.init_weights)

    def forward(self, x):
        x, short_cuts = self.encode(x)
        x = self.decode(x, short_cuts)
        logit = self.cls(x)
        return logit

class Encoder(nn.Module):
    """U-Net encoder: a stem double conv followed by four max-pool downsampling stages."""

    def __init__(self, base_width):
        super().__init__()
        self.double_conv = nn.Sequential(
            layers.ConvBNAct(3, base_width, 3),
            layers.ConvBNAct(base_width, base_width, 3))
        # Channel pairs for the four downsampling stages, e.g. 64->128, 128->256, ...
        down_channels = [[base_width * 2 ** i, base_width * 2 ** (i + 1)] for i in range(4)]
        self.down_sample_list = nn.ModuleList([
            self.down_sampling(channel[0], channel[1])
            for channel in down_channels
        ])

    @staticmethod
    def down_sampling(in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            layers.ConvBNAct(in_channels, out_channels, 3),
            layers.ConvBNAct(out_channels, out_channels, 3))

    def forward(self, x):
        short_cuts = []
        x = self.double_conv(x)
        for down_sample in self.down_sample_list:
            # Save the feature map before each downsampling for the decoder's skip connections.
            short_cuts.append(x)
            x = down_sample(x)
        return x, short_cuts

class Decoder(nn.Module):
    """U-Net decoder: four upsampling stages that fuse the encoder's skip connections."""

    def __init__(self, base_width, align_corners, use_deconv=False):
        super().__init__()
        # Channel pairs for the four upsampling stages, e.g. 1024->512, 512->256, ...
        up_channels = [[base_width * 2 ** (i + 1), base_width * 2 ** i] for i in reversed(range(4))]
        self.up_sample_list = nn.ModuleList([
            Upsampling(channel[0], channel[1], align_corners, use_deconv)
            for channel in up_channels
        ])

    def forward(self, x, short_cuts):
        # Skip connections are consumed in reverse order: deepest first.
        for i in range(len(short_cuts)):
            x = self.up_sample_list[i](x, short_cuts[-(i + 1)])
        return x

class Upsampling(nn.Module):
    """One decoder stage: upsample, concatenate the skip connection, then apply a double conv."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 align_corners,
                 use_deconv=False):
        super().__init__()
        self.align_corners = align_corners
        self.use_deconv = use_deconv

        if use_deconv:
            self.deconv = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=2, stride=2, padding=0)
            # After the deconv, x has out_channels; concatenating the skip doubles it.
            in_channels = out_channels * 2
        else:
            # Bilinear upsampling keeps in_channels; the skip contributes out_channels more.
            in_channels = in_channels + out_channels
        self.double_conv = nn.Sequential(
            layers.ConvBNAct(in_channels, out_channels, 3),
            layers.ConvBNAct(out_channels, out_channels, 3))

    def forward(self, x, short_cut):
        if self.use_deconv:
            x = self.deconv(x)
        else:
            x = F.interpolate(
                x,
                short_cut.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        x = torch.cat([x, short_cut], dim=1)
        x = self.double_conv(x)
        return x

if __name__ == "__main__":
    dummy_input = torch.rand(1, 3, 256, 256)
    net = UNet(13, base_width=64, align_corners=False, use_deconv=True)
    output = net(dummy_input)
    print(output.shape)  # torch.Size([1, 13, 256, 256]), assuming ConvBNAct uses 'same' padding
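

# Note: `models.layers` is not shown in this file. The sketch below is a minimal,
# hypothetical stand-in for the interface the classes above assume (it is NOT the
# project's actual models/layers.py): ConvBNAct(in_channels, out_channels, kernel_size)
# is taken to be Conv2d -> BatchNorm2d -> ReLU with 'same' padding, and init_weights
# is a function suitable for nn.Module.apply.
class _SketchConvBNAct(nn.Module):
    """Illustrative stand-in for layers.ConvBNAct (assumed interface, for reference only)."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.block(x)


def _sketch_init_weights(m):
    """Illustrative stand-in for layers.init_weights: Kaiming init for conv layers."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)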