# D:/workplace/python
# -*- coding: utf-8 -*-
# @File     : bisenetv2.py
# @Author   : Guido LuXiaohao
# @Date     : 2022/3/11
# @Software : PyCharm
"""
BiSeNetV2 for real-time semantic segmentation.

Adapted from: https://github.com/CoinCheung/BiSeNet
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import registry


class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1,
                 dilation=1, groups=1, bias=False):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_chan, out_chan, kernel_size=ks, stride=stride,
            padding=padding, dilation=dilation,
            groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU()

    def forward(self, x):
        feat = self.conv(x)
        feat = self.bn(feat)
        feat = self.relu(feat)
        return feat
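

# Detail Branch of BiSeNetV2: three stages of plain ConvBNReLU blocks with wide
# channels and an overall output stride of 8, kept shallow so it preserves
# low-level spatial detail at low cost.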
class DetailBranch(nn.Module):
    def __init__(self):
        super(DetailBranch, self).__init__()
        self.S1 = nn.Sequential(
            ConvBNReLU(3, 32, 3, stride=2),
            ConvBNReLU(32, 32, 3, stride=1),
        )
        self.S2 = nn.Sequential(
            ConvBNReLU(32, 64, 3, stride=2),
            ConvBNReLU(64, 64, 3, stride=1),
            ConvBNReLU(64, 64, 3, stride=1),
        )
        self.S3 = nn.Sequential(
            ConvBNReLU(64, 128, 3, stride=2),
            ConvBNReLU(128, 128, 3, stride=1),
            ConvBNReLU(128, 128, 3, stride=1),
        )

    def forward(self, x):
        feat = self.S1(x)
        feat = self.S2(feat)
        feat = self.S3(feat)
        return feat
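

# Stem Block: the entry of the Semantic Branch. It downsamples the input by 4
# using two parallel paths (a strided conv path and a max-pooling path) whose
# outputs are concatenated and fused by a 3x3 conv.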
class StemBlock(nn.Module):
    def __init__(self):
        super(StemBlock, self).__init__()
        self.conv = ConvBNReLU(3, 16, 3, stride=2)
        self.left = nn.Sequential(
            ConvBNReLU(16, 8, 1, stride=1, padding=0),
            ConvBNReLU(8, 16, 3, stride=2),
        )
        self.right = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=False)
        self.fuse = ConvBNReLU(32, 16, 3, stride=1)

    def forward(self, x):
        feat = self.conv(x)
        feat_left = self.left(feat)
        feat_right = self.right(feat)
        feat = torch.cat([feat_left, feat_right], dim=1)
        feat = self.fuse(feat)
        return feat
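

# Context Embedding Block: global average pooling embeds global context, which
# is re-projected by a 1x1 conv and added back onto the input feature map
# before a final 3x3 conv.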
class CEBlock(nn.Module):
    def __init__(self):
        super(CEBlock, self).__init__()
        self.bn = nn.BatchNorm2d(128)
        self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
        self.conv_last = ConvBNReLU(128, 128, 3, stride=1)

    def forward(self, x):
        feat = torch.mean(x, dim=(2, 3), keepdim=True)
        feat = self.bn(feat)
        feat = self.conv_gap(feat)
        feat = feat + x
        feat = self.conv_last(feat)
        return feat
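

# Gather-and-Expansion layers: inverted-bottleneck blocks that expand the
# channels with a 3x3 depthwise conv (multiplier = exp_ratio) and project back
# with a pointwise 1x1 conv. GELayerS1 keeps the resolution and uses an
# identity shortcut; GELayerS2 downsamples by 2 and uses a depthwise-separable
# shortcut instead.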
class GELayerS1(nn.Module):
    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super(GELayerS1, self).__init__()
        mid_chan = in_chan * exp_ratio
        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        self.dwconv = nn.Sequential(
            nn.Conv2d(
                in_chan, mid_chan, kernel_size=3, stride=1,
                padding=1, groups=in_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, out_chan, kernel_size=1, stride=1,
                padding=0, bias=False),
            nn.BatchNorm2d(out_chan)
        )
        self.conv2[1].last_bn = True
        self.relu = nn.ReLU()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.dwconv(feat)
        feat = self.conv2(feat)
        feat = feat + x
        feat = self.relu(feat)
        return feat


class GELayerS2(nn.Module):
    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super(GELayerS2, self).__init__()
        mid_chan = in_chan * exp_ratio
        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        self.dwconv1 = nn.Sequential(
            nn.Conv2d(
                in_chan, mid_chan, kernel_size=3, stride=2,
                padding=1, groups=in_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
        )
        self.dwconv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, mid_chan, kernel_size=3, stride=1,
                padding=1, groups=mid_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, out_chan, kernel_size=1, stride=1,
                padding=0, bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.conv2[1].last_bn = True
        self.shortcut = nn.Sequential(
            nn.Conv2d(
                in_chan, in_chan, kernel_size=3, stride=2,
                padding=1, groups=in_chan, bias=False),
            nn.BatchNorm2d(in_chan),
            nn.Conv2d(
                in_chan, out_chan, kernel_size=1, stride=1,
                padding=0, bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.dwconv1(feat)
        feat = self.dwconv2(feat)
        feat = self.conv2(feat)
        shortcut = self.shortcut(x)
        feat = feat + shortcut
        feat = self.relu(feat)
        return feat
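

# Semantic Branch: stem block (1/4 resolution), then stacks of GE layers down
# to 1/32, finished with the context embedding block. All intermediate feature
# maps are returned so auxiliary heads can supervise them during training.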
class SegmentBranch(nn.Module):
    def __init__(self):
        super(SegmentBranch, self).__init__()
        self.S1S2 = StemBlock()
        self.S3 = nn.Sequential(
            GELayerS2(16, 32),
            GELayerS1(32, 32),
        )
        self.S4 = nn.Sequential(
            GELayerS2(32, 64),
            GELayerS1(64, 64),
        )
        self.S5_4 = nn.Sequential(
            GELayerS2(64, 128),
            GELayerS1(128, 128),
            GELayerS1(128, 128),
            GELayerS1(128, 128),
        )
        self.S5_5 = CEBlock()

    def forward(self, x):
        feat2 = self.S1S2(x)
        feat3 = self.S3(feat2)
        feat4 = self.S4(feat3)
        feat5_4 = self.S5_4(feat4)
        feat5_5 = self.S5_5(feat5_4)
        return feat2, feat3, feat4, feat5_4, feat5_5
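

# Bilateral Guided Aggregation layer: fuses the Detail Branch output (x_d, 1/8
# resolution) with the Semantic Branch output (x_s, 1/32 resolution) by letting
# each branch gate the other through sigmoid attention before summing.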
class BGALayer(nn.Module):
    def __init__(self):
        super(BGALayer, self).__init__()
        self.left1 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, groups=128, bias=False),
            nn.BatchNorm2d(128),
            nn.Conv2d(
                128, 128, kernel_size=1, stride=1,
                padding=0, bias=False),
        )
        self.left2 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=2,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
        )
        self.right1 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
        )
        self.right2 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, groups=128, bias=False),
            nn.BatchNorm2d(128),
            nn.Conv2d(
                128, 128, kernel_size=1, stride=1,
                padding=0, bias=False),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

    def forward(self, x_d, x_s):
        dsize = x_d.size()[2:]
        left1 = self.left1(x_d)
        left2 = self.left2(x_d)
        right1 = self.right1(x_s)
        right2 = self.right2(x_s)
        right1 = F.interpolate(
            right1, size=dsize, mode='bilinear', align_corners=True)
        left = left1 * torch.sigmoid(right1)
        right = left2 * torch.sigmoid(right2)
        right = F.interpolate(
            right, size=dsize, mode='bilinear', align_corners=True)
        out = self.conv(left + right)
        return out
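

# Segmentation head: 3x3 conv, optional dropout, then a 1x1 classifier, with
# bilinear upsampling to the requested output size. Used both for the main
# prediction and for the auxiliary (booster) heads during training.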
class BiSeNetHead(nn.Module):
    def __init__(self, in_chan, mid_chan, num_classes, dropout_ratio):
        super(BiSeNetHead, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
        self.conv_seg = nn.Conv2d(
            mid_chan, num_classes, kernel_size=1, stride=1,
            padding=0, bias=False)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout(dropout_ratio)
        else:
            self.dropout = None

    def forward(self, x, size=None):
        feat = self.conv(x)
        if self.dropout is not None:
            feat = self.dropout(feat)
        feat = self.conv_seg(feat)
        if size is not None:
            feat = F.interpolate(
                feat, size=size, mode='bilinear', align_corners=True)
        return feat
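

# DenseASPP modules: a stack of atrous convolutions whose inputs are densely
# concatenated, so the dilation rates (3, 5, 7, 9, 12 here) combine into a
# large, densely sampled receptive field. Used by the BiseNetV2DenseASPP
# variant to refine the fused 1/8-resolution feature map.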
class _DenseASPPConv(nn.Sequential):
    def __init__(self, in_channels, inter_channels, out_channels, atrous_rate,
                 drop_rate=0.1, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPConv, self).__init__()
        self.add_module('conv1', nn.Conv2d(in_channels, inter_channels, 1))
        self.add_module('bn1', norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)))
        self.add_module('relu1', nn.ReLU())
        self.add_module('conv2', nn.Conv2d(inter_channels, out_channels, 3,
                                           dilation=atrous_rate, padding=atrous_rate))
        self.add_module('bn2', norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)))
        self.add_module('relu2', nn.ReLU())
        self.drop_rate = drop_rate

    def forward(self, x):
        features = super(_DenseASPPConv, self).forward(x)
        if self.drop_rate > 0:
            features = F.dropout(features, p=self.drop_rate, training=self.training)
        return features


class _DenseASPPBlock(nn.Module):
    def __init__(self, in_channels, inter_channels1, inter_channels2,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPBlock, self).__init__()
        self.aspp_3 = _DenseASPPConv(in_channels, inter_channels1, inter_channels2, 3, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_5 = _DenseASPPConv(in_channels + inter_channels2 * 1, inter_channels1, inter_channels2, 5, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_7 = _DenseASPPConv(in_channels + inter_channels2 * 2, inter_channels1, inter_channels2, 7, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_9 = _DenseASPPConv(in_channels + inter_channels2 * 3, inter_channels1, inter_channels2, 9, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_12 = _DenseASPPConv(in_channels + inter_channels2 * 4, inter_channels1, inter_channels2, 12, 0.1,
                                      norm_layer, norm_kwargs)

    def forward(self, x):
        # Dense connectivity: each ASPP branch sees the input concatenated with
        # the outputs of all previous branches.
        aspp3 = self.aspp_3(x)
        x = torch.cat([aspp3, x], dim=1)
        aspp5 = self.aspp_5(x)
        x = torch.cat([aspp5, x], dim=1)
        aspp7 = self.aspp_7(x)
        x = torch.cat([aspp7, x], dim=1)
        aspp9 = self.aspp_9(x)
        x = torch.cat([aspp9, x], dim=1)
        aspp12 = self.aspp_12(x)
        x = torch.cat([aspp12, x], dim=1)
        return x


class _DenseASPPHead(nn.Module):
    def __init__(self, in_channels, num_classes, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DenseASPPHead, self).__init__()
        self.dense_aspp_block = _DenseASPPBlock(in_channels, 256, 64, norm_layer, norm_kwargs)
        self.block = nn.Sequential(
            nn.Dropout(0.1),
            nn.Conv2d(in_channels + 5 * 64, num_classes, 1)
        )

    def forward(self, x):
        x = self.dense_aspp_block(x)
        return self.block(x)
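

# Full BiSeNetV2 model: Detail Branch + Semantic Branch fused by the BGA layer.
# In training mode every intermediate feature map gets its own auxiliary head
# and all logits (upsampled to the input size) are returned; in eval mode only
# the main head is run, optionally followed by softmax.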
@registry.MODELS.register_module
class BiSeNetV2(nn.Module):
    def __init__(self, num_classes, dropout_ratio=0.1, apply_nonlin=False):
        super(BiSeNetV2, self).__init__()
        self.detail = DetailBranch()
        self.segment = SegmentBranch()
        self.bga = BGALayer()
        self.heads = nn.ModuleList([
            BiSeNetHead(16, 128, num_classes, dropout_ratio),   # aux2 segmentation head
            BiSeNetHead(32, 128, num_classes, dropout_ratio),   # aux3 segmentation head
            BiSeNetHead(64, 128, num_classes, dropout_ratio),   # aux4 segmentation head
            BiSeNetHead(128, 128, num_classes, dropout_ratio),  # aux5_4 segmentation head
            BiSeNetHead(128, 128, num_classes, dropout_ratio),  # main segmentation head
        ])
        self.apply_nonlin = apply_nonlin

    def forward(self, x):
        img_sz = x.size()[2:]
        feat_d = self.detail(x)
        feat_out = self.segment(x)
        feat_out = list(feat_out)
        feat_out[-1] = self.bga(feat_d, feat_out[-1])
        outputs = []
        if self.training:
            for i, head in enumerate(self.heads):
                out = head(feat_out[i], img_sz)
                outputs.append(out)
        else:
            out = self.heads[-1](feat_out[-1], img_sz)
            if self.apply_nonlin:
                out = F.softmax(out, dim=1)
            outputs.append(out)
        return outputs
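

# Variant that applies a DenseASPP refinement on top of the BGA output before
# the main segmentation head; its _DenseASPPHead is configured to output 128
# channels so the existing heads can be reused unchanged.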
@registry.MODELS.register_module
class BiseNetV2DenseASPP(nn.Module):
    def __init__(self, num_classes, dropout_ratio=0.1, apply_nonlin=False):
        super(BiseNetV2DenseASPP, self).__init__()
        self.detail = DetailBranch()
        self.segment = SegmentBranch()
        self.bga = BGALayer()
        self.heads = nn.ModuleList([
            BiSeNetHead(16, 128, num_classes, dropout_ratio),   # aux2 segmentation head
            BiSeNetHead(32, 128, num_classes, dropout_ratio),   # aux3 segmentation head
            BiSeNetHead(64, 128, num_classes, dropout_ratio),   # aux4 segmentation head
            BiSeNetHead(128, 128, num_classes, dropout_ratio),  # aux5_4 segmentation head
            BiSeNetHead(128, 128, num_classes, dropout_ratio),  # main segmentation head
        ])
        self.denseaspp_head = _DenseASPPHead(128, 128)
        self.apply_nonlin = apply_nonlin

    def forward(self, x):
        img_sz = x.size()[2:]
        feat_d = self.detail(x)
        feat_out = self.segment(x)
        feat_out = list(feat_out)
        feat_out[-1] = self.denseaspp_head(self.bga(feat_d, feat_out[-1]))
        outputs = []
        if self.training:
            for i, head in enumerate(self.heads):
                out = head(feat_out[i], img_sz)
                outputs.append(out)
        else:
            out = self.heads[-1](feat_out[-1], img_sz)
            if self.apply_nonlin:
                out = F.softmax(out, dim=1)
            outputs.append(out)
        return outputs
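

# Quick shape check in training mode. A minimal eval-mode inference sketch
# would look like the following (the checkpoint path is hypothetical):
#     model = BiSeNetV2(num_classes=3, apply_nonlin=True)
#     model.load_state_dict(torch.load('bisenetv2.pth'))
#     model.eval()
#     with torch.no_grad():
#         probs = model(torch.rand(1, 3, 256, 256))[0]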
if __name__ == '__main__':
    dummy_data = torch.rand(2, 3, 256, 256)
    model = BiSeNetV2(num_classes=3)
    outputs = model(dummy_data)
    for output in outputs:
        print(output.shape)