# -*- coding: utf-8 -*-
# @Time : 2022/11/16 15:30
# @Author : Marvin.yuan
# @File : unext.py
# @Description : https://github.com/jeya-maria-jose/UNeXt-pytorch/blob/6ad0855114a35afbf81decf5dc912cd8de70476a/archs.py#L206
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import layers
from models.layers.drop import DropPath
from utils import registry
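
# Architecture overview (an added note, not in the referenced upstream file):
# UNeXt pairs a lightweight convolutional encoder/decoder with two
# "tokenized MLP" stages. The first three encoder stages are plain
# conv + BN + ReLU blocks with max-pooling; the last two stages flatten the
# feature map into tokens (OverlapPatchEmbed) and process them with
# shifted-MLP blocks (TokenizedMLPBlock, built on ShiftedMLP and DWConv).
# The decoder mirrors this, adding skip connections element-wise at each
# resolution.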


@registry.MODELS.register_module
class UNeXt(nn.Module):
    """UNeXt: a 3-stage convolutional encoder followed by 2 tokenized
    (shifted-MLP) stages, with a mirrored decoder and additive skips."""

    def __init__(self,
                 num_classes,
                 input_channels=3,
                 img_size=224,
                 embed_dims=(128, 160, 256),
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=(1, 1, 1),
                 apply_nonlin=False,
                 **kwargs):
        super().__init__()

        # Conv encoder (stages 1-3); stage 3 feeds the first tokenized stage.
        self.encoder1 = nn.Conv2d(input_channels, 16, 3, stride=1, padding=1)
        self.encoder2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        self.encoder3 = nn.Conv2d(32, embed_dims[0], 3, stride=1, padding=1)
        self.ebn1 = nn.BatchNorm2d(16)
        self.ebn2 = nn.BatchNorm2d(32)
        self.ebn3 = nn.BatchNorm2d(embed_dims[0])

        self.norm3 = norm_layer(embed_dims[1])
        self.norm4 = norm_layer(embed_dims[2])
        self.dnorm3 = norm_layer(embed_dims[1])
        self.dnorm4 = norm_layer(embed_dims[0])

        # Stochastic-depth schedule, shared by encoder and decoder MLP blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.block1 = nn.ModuleList([TokenizedMLPBlock(
            dim=embed_dims[1], mlp_ratio=1, drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer)])
        self.block2 = nn.ModuleList([TokenizedMLPBlock(
            dim=embed_dims[2], mlp_ratio=1, drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer)])
        self.dblock1 = nn.ModuleList([TokenizedMLPBlock(
            dim=embed_dims[1], mlp_ratio=1, drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer)])
        self.dblock2 = nn.ModuleList([TokenizedMLPBlock(
            dim=embed_dims[0], mlp_ratio=1, drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer)])

        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2,
                                              in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2,
                                              in_chans=embed_dims[1], embed_dim=embed_dims[2])

        # Decoder; channel widths mirror the encoder stage by stage.
        self.decoder1 = layers.ConvBNAct(embed_dims[2], embed_dims[1], 3)
        self.decoder2 = layers.ConvBNAct(embed_dims[1], embed_dims[0], 3)
        self.decoder3 = layers.ConvBNAct(embed_dims[0], 32, 3)
        self.decoder4 = layers.ConvBNAct(32, 16, 3)
        self.decoder5 = nn.Conv2d(16, 16, 3, padding=1, bias=False)
        self.dbn1 = nn.BatchNorm2d(embed_dims[1])
        self.dbn2 = nn.BatchNorm2d(embed_dims[0])
        self.dbn3 = nn.BatchNorm2d(32)
        self.dbn4 = nn.BatchNorm2d(16)

        self.conv_seg = nn.Conv2d(16, num_classes, kernel_size=1)
        self.apply_nonlin = apply_nonlin

    def forward(self, x):
        B = x.shape[0]

        # ----- Encoder -----
        # Conv stage 1
        out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)), 2, 2))
        t1 = out
        # Conv stage 2
        out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)), 2, 2))
        t2 = out
        # Conv stage 3
        out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)), 2, 2))
        t3 = out

        # Tokenized MLP stage 4
        out, H, W = self.patch_embed3(out)
        for blk in self.block1:
            out = blk(out, H, W)
        out = self.norm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t4 = out

        # Bottleneck
        out, H, W = self.patch_embed4(out)
        for blk in self.block2:
            out = blk(out, H, W)
        out = self.norm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        # ----- Decoder -----
        # Stage 4
        out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t4)
        _, _, H, W = out.shape
        out = out.flatten(2).transpose(1, 2)
        for blk in self.dblock1:
            out = blk(out, H, W)

        # Stage 3
        out = self.dnorm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t3)
        _, _, H, W = out.shape
        out = out.flatten(2).transpose(1, 2)
        for blk in self.dblock2:
            out = blk(out, H, W)
        out = self.dnorm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        # Stage 2
        out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t2)
        # Stage 1
        out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)), scale_factor=(2, 2), mode='bilinear'))
        out = torch.add(out, t1)
        out = F.relu(F.interpolate(self.decoder5(out), scale_factor=(2, 2), mode='bilinear'))

        out = self.conv_seg(out)
        if self.apply_nonlin:
            out = F.softmax(out, dim=1)
        return out
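

# Shape walk-through (an added sketch, assuming the default
# embed_dims=(128, 160, 256) and a 256x256 input of batch size B):
#   input        (B,   3, 256, 256)
#   t1           (B,  16, 128, 128)   conv stage 1 + max-pool
#   t2           (B,  32,  64,  64)   conv stage 2 + max-pool
#   t3           (B, 128,  32,  32)   conv stage 3 + max-pool
#   t4           (B, 160,  16,  16)   patch_embed3 + block1
#   bottleneck   (B, 256,   8,   8)   patch_embed4 + block2
#   output       (B, num_classes, 256, 256) after five 2x upsamplings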


class DWConv(nn.Module):
    """Depth-wise 3x3 convolution applied to a token sequence by
    temporarily restoring its (B, C, H, W) spatial layout."""

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x
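

# Note (added): the depth-wise conv gives the token MLP a local spatial
# receptive field between fc1 and fc2, serving as a lightweight positional
# cue for the otherwise order-agnostic linear layers.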


class ShiftedMLP(nn.Module):
    """MLP whose two linear layers each see an axially shifted view of the
    feature map: a shift along height before fc1, along width before fc2."""

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.,
                 shift_size=5):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.shift_size = shift_size
        self.pad = shift_size // 2
        self.apply(layers.init_weights)

    def forward(self, x: torch.Tensor, H: int, W: int):
        # Shift along the height axis (dim 2) before fc1.
        B, N, C = x.shape
        xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
        xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad), "constant", 0)
        xs = torch.chunk(xn, self.shift_size, 1)
        x_shift = [torch.roll(x_c, shift, 2) for x_c, shift in zip(xs, range(-self.pad, self.pad + 1))]
        x_cat = torch.cat(x_shift, 1)
        x_cat = torch.narrow(x_cat, 2, self.pad, H)
        x_s = torch.narrow(x_cat, 3, self.pad, W)
        x_s = x_s.reshape(B, C, H * W).contiguous()
        x_shift_r = x_s.transpose(1, 2)

        x = self.fc1(x_shift_r)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)

        # Shift along the width axis (dim 3) before fc2. Re-read the channel
        # count, since fc1 changes it whenever hidden_features != in_features.
        B, N, C = x.shape
        xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
        xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad), "constant", 0)
        xs = torch.chunk(xn, self.shift_size, 1)
        x_shift = [torch.roll(x_c, shift, 3) for x_c, shift in zip(xs, range(-self.pad, self.pad + 1))]
        x_cat = torch.cat(x_shift, 1)
        x_cat = torch.narrow(x_cat, 2, self.pad, H)
        x_s = torch.narrow(x_cat, 3, self.pad, W)
        x_s = x_s.reshape(B, C, H * W).contiguous()
        x_shift_c = x_s.transpose(1, 2)

        x = self.fc2(x_shift_c)
        x = self.drop(x)
        return x
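

# How the shift works (an added note): with shift_size=5 and pad=2, the
# channels are chunked into 5 groups and group i is rolled by offset
# i - 2 (i.e. -2, -1, 0, +1, +2) along one spatial axis. Padding by 2 on
# every side before the roll, then narrowing back to (H, W) afterwards,
# discards the values that torch.roll would otherwise wrap around the
# border, so out-of-range positions read the zero padding instead.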


class TokenizedMLPBlock(nn.Module):
    """Pre-norm residual block: x + DropPath(ShiftedMLP(LayerNorm(x)))."""

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ShiftedMLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.apply(layers.init_weights)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.mlp(self.norm(x), H, W))
        return x


class OverlapPatchEmbed(nn.Module):
    """Image-to-patch embedding with overlapping patches (a strided conv)."""

    def __init__(self,
                 img_size=224,
                 patch_size=7,
                 stride=4,
                 in_chans=3,
                 embed_dim=768):
        super().__init__()
        img_size = (img_size, img_size) if isinstance(img_size, int) else img_size[:2]
        patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size[:2]
        self.img_size = img_size
        self.patch_size = patch_size
        # Nominal token grid: with overlapping patches the conv stride (not
        # the patch size) sets the downsampling factor.
        self.H, self.W = img_size[0] // stride, img_size[1] // stride
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(layers.init_weights)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
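

# Output grid size (an added note): for the strided conv above,
# H_out = floor((H + 2 * (patch_size // 2) - patch_size) / stride) + 1;
# e.g. patch_size=3, stride=2, padding=1 gives H_out = (H - 1) // 2 + 1,
# which halves an even H. forward() therefore returns the actual (H, W)
# of the projected map rather than relying on the precomputed grid.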


if __name__ == "__main__":
    model = UNeXt(num_classes=3, img_size=256)
    model.eval()
    dummy = torch.rand(2, 3, 256, 256)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)
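    # A minimal sanity check (an added sketch): the decoder restores the
    # input resolution, so the logits should be (B, num_classes, H, W).
    assert logits.shape == (2, 3, 256, 256)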