123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
- # -*- coding: utf-8 -*-
- # @Time : 2022/10/20 16:25
- # @Author : Marvin.yuan
- # @File : pp_liteseg.py
- # @Description :
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .components.decoders import BaseSegDecoder
- from .layers import (ConvBN, ConvBNAct, CustomAdaptiveAvgPool2d,
- custom_adaptive_avg_pool2d, custom_adaptive_max_pool2d)
- from utils import registry
@registry.MODELS.register_module
class PPLiteSegNeck(nn.Module):
    """Neck of PP-LiteSeg: a PP context module followed by a chain of fusion modules.

    The deepest selected feature goes through ``PPContextModule``. Each fusion
    module (one per selected feature) then merges a selected feature with the
    output of the next-deeper stage, walking from the deepest level to the
    shallowest.

    Args:
        arm_in_chs (list[int]): Channels of the selected backbone features.
            The last entry is the feature fed to the context module.
        in_index (int | list[int] | tuple[int]): Index or indices into the
            backbone outputs selecting the features to fuse.
        arm_out_chs (list[int]): Output channels of each fusion module, in the
            same order as ``arm_in_chs``.
        arm_type (str): Expression evaluated with ``eval`` to obtain the fusion
            class, e.g. ``'UAFM_SpAtten'``.
        cm_bin_sizes (tuple): Pooling bin sizes of the PP context module.
        cm_out_ch (int): Output channels of the PP context module.
        resize_mode (str, optional): Interpolation mode passed to the fusion
            modules. Default: 'bilinear'.
    """

    def __init__(self, arm_in_chs, in_index, arm_out_chs, arm_type,
                 cm_bin_sizes, cm_out_ch, resize_mode='bilinear'):
        super().__init__()
        # Accept tuples as well as lists. Previously a tuple was wrapped into
        # a one-element list, and indexing the inputs with it failed in forward().
        if isinstance(in_index, (list, tuple)):
            self.in_index = list(in_index)
        else:
            self.in_index = [in_index]
        # PP Context Module, applied to the deepest (last) selected feature.
        self.cm = PPContextModule(arm_in_chs[-1], cm_out_ch, cm_out_ch, cm_bin_sizes)
        # NOTE(review): eval() executes arbitrary text from the config. It is
        # only safe while arm_type comes from trusted configs; consider a
        # whitelist lookup of the UAFM classes instead.
        arm_class = eval(arm_type)
        self.arm_list = nn.ModuleList()  # [..., arm8, arm16, arm32]
        for i in range(len(arm_in_chs)):
            low_chs = arm_in_chs[i]
            # The deepest arm fuses with the context-module output; every other
            # arm fuses with the output of the next-deeper arm.
            high_ch = cm_out_ch if i == len(arm_in_chs) - 1 else arm_out_chs[i + 1]
            out_ch = arm_out_chs[i]
            arm = arm_class(
                low_chs, high_ch, out_ch, ksize=3, resize_mode=resize_mode)
            self.arm_list.append(arm)

    def forward(self, inputs):
        """Fuse the selected features from deep to shallow.

        Args:
            inputs: Indexable collection of backbone features; the entries at
                ``self.in_index`` are used.

        Returns:
            list: Fused features, one per selected input, in the same order as
            ``self.in_index``.
        """
        in_feat_list = [inputs[idx] for idx in self.in_index]
        high_feat = self.cm(in_feat_list[-1])
        out_feat_list = []
        for i in reversed(range(len(in_feat_list))):
            high_feat = self.arm_list[i](in_feat_list[i], high_feat)
            out_feat_list.insert(0, high_feat)
        return out_feat_list
@registry.MODELS.register_module
class PPLiteSegDecoder(BaseSegDecoder):
    """The Decoder of PPLiteSeg.

    Builds a ``PPContextModule`` on the deepest input feature and one fusion
    module per input level. The fused features are handed to
    ``self.decode_head``, which, like ``self.in_channels`` and
    ``self._transform_inputs``, comes from ``BaseSegDecoder``. Presumably these
    are configured through ``**kwargs``; verify against the base class.

    Args:
        arm_out_chs (list[int]): Output channels of each fusion module, ordered
            like ``self.in_channels``.
        arm_type (str): Expression evaluated with ``eval`` to obtain the fusion
            class, e.g. ``'UAFM_SpAtten'``.
        cm_bin_sizes (tuple): Pooling bin sizes of the PP context module.
        cm_out_ch (int): Output channels of the PP context module.
        resize_mode (str, optional): Interpolation mode passed to the fusion
            modules. Default: 'bilinear'.
        **kwargs: Forwarded to ``BaseSegDecoder``.
    """

    def __init__(self, arm_out_chs, arm_type, cm_bin_sizes, cm_out_ch, resize_mode='bilinear', **kwargs):
        super().__init__(**kwargs)
        # Context module on the deepest (last) input level.
        self.cm = PPContextModule(self.in_channels[-1], cm_out_ch, cm_out_ch, cm_bin_sizes)
        # NOTE(review): eval() executes arbitrary config text; only safe for
        # trusted configs.
        arm_class = eval(arm_type)
        self.arm_list = nn.ModuleList()  # [..., arm8, arm16, arm32]
        deepest = len(self.in_channels) - 1
        for level, low_ch in enumerate(self.in_channels):
            # The deepest arm consumes the context output; the others consume
            # the output of the arm one level deeper.
            if level == deepest:
                high_ch = cm_out_ch
            else:
                high_ch = arm_out_chs[level + 1]
            self.arm_list.append(
                arm_class(low_ch, high_ch, arm_out_chs[level],
                          ksize=3, resize_mode=resize_mode))

    def _forward_features(self, inputs):
        """Fuse multi-level features from the deepest level to the shallowest.

        Args:
            inputs (List(Tensor)): Such as [x2, x4, x8, x16, x32];
                x2, x4 and x8 are optional.

        Returns:
            List(Tensor): Fused features with the same length and ordering as
            the transformed inputs, e.g. [x2, x4, x8, x16, x32].
        """
        feats = self._transform_inputs(inputs)
        fused = self.cm(feats[-1])
        outputs = [None] * len(feats)
        for level in range(len(feats) - 1, -1, -1):
            fused = self.arm_list[level](feats[level], fused)
            outputs[level] = fused
        return outputs

    def forward(self, inputs, data_samples=None):
        """Fuse ``inputs`` and run the decode head; ``data_samples`` is unused."""
        return self.decode_head(self._forward_features(inputs))
class UAFM(nn.Module):
    """Base Unified Attention Fusion Module with plain additive fusion.

    The low-level feature ``x`` is projected to ``y_ch`` channels by
    ``conv_x``. The high-level feature ``y`` is resized to the projected
    ``x``'s spatial size. The two are summed and passed through ``conv_out``.
    Subclasses override ``fuse`` to blend with attention instead.

    Args:
        x_ch (int): Channels of ``x``, the low-level feature.
        y_ch (int): Channels of ``y``, the high-level feature.
        out_ch (int): Channels of the output tensor.
        ksize (int, optional): Kernel size of the conv applied to ``x``. Default: 3.
        resize_mode (str, optional): ``mode`` passed to ``F.interpolate`` when
            resizing ``y``. Default: 'bilinear'.
    """

    def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'):
        super().__init__()
        self.conv_x = ConvBNAct(x_ch, y_ch, kernel_size=ksize)
        self.conv_out = ConvBNAct(y_ch, out_ch, kernel_size=3)
        self.resize_mode = resize_mode

    def check(self, x, y):
        """Assert both inputs are 4-D and ``x`` is spatially no smaller than ``y``."""
        assert x.ndim == 4 and y.ndim == 4
        (x_h, x_w), (y_h, y_w) = x.shape[2:], y.shape[2:]
        assert x_h >= y_h and x_w >= y_w

    def prepare(self, x, y):
        """Return ``(prepare_x(x, y), prepare_y(<prepared x>, y))``."""
        x_ready = self.prepare_x(x, y)
        # y is resized against the already-projected x.
        y_ready = self.prepare_y(x_ready, y)
        return x_ready, y_ready

    def prepare_x(self, x, y):
        """Project the low-level feature with ``conv_x``."""
        return self.conv_x(x)

    def prepare_y(self, x, y):
        """Resize ``y`` to the spatial size of ``x``."""
        return F.interpolate(y, size=x.shape[-2:], mode=self.resize_mode)

    def fuse(self, x, y):
        """Sum the prepared features and apply ``conv_out``."""
        return self.conv_out(x + y)

    def forward(self, x, y):
        """
        Args:
            x (Tensor): The low-level feature.
            y (Tensor): The high-level feature.
        """
        self.check(x, y)
        return self.fuse(*self.prepare(x, y))
class UAFM_ChAtten(UAFM):
    """UAFM whose fusion weight is a channel attention built from mean and max.

    Both inputs are globally pooled by mean and by max. The four pooled tensors
    are concatenated, hence ``4 * y_ch`` input channels to ``conv_xy_atten``.
    The sigmoid of ``conv_xy_atten`` gives a weight ``a`` used as
    ``x * a + y * (1 - a)`` before ``conv_out``.

    Args:
        x_ch (int): Channels of ``x``, the low-level feature.
        y_ch (int): Channels of ``y``, the high-level feature.
        out_ch (int): Channels of the output tensor.
        ksize (int, optional): Kernel size of the conv applied to ``x``. Default: 3.
        resize_mode (str, optional): ``mode`` passed to ``F.interpolate`` when
            resizing ``y``. Default: 'bilinear'.
    """

    def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'):
        super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode)
        hidden_ch = y_ch // 2
        self.conv_xy_atten = nn.Sequential(
            ConvBNAct(4 * y_ch, hidden_ch, kernel_size=1),
            ConvBN(hidden_ch, y_ch, kernel_size=1))

    @staticmethod
    def avg_max_reduce_hw(x):
        """Pool each tensor to 1x1 by mean and by max, then concatenate on dim 1.

        Returns ``cat([avg_0, avg_1, ..., max_0, max_1, ...], dim=1)``.
        """
        tensors = x if isinstance(x, (list, tuple)) else [x]
        avg_pooled, max_pooled = [], []
        for t in tensors:
            avg_pooled.append(custom_adaptive_avg_pool2d(t, 1))
            max_pooled.append(custom_adaptive_max_pool2d(t, 1))
        return torch.cat(avg_pooled + max_pooled, dim=1)

    def fuse(self, x, y):
        """
        Args:
            x (Tensor): The low-level feature.
            y (Tensor): The high-level feature.
        """
        pooled = self.avg_max_reduce_hw([x, y])
        weight = torch.sigmoid(self.conv_xy_atten(pooled))
        blended = x * weight + y * (1 - weight)
        return self.conv_out(blended)
class UAFM_ChAtten_S(UAFM):
    """UAFM whose fusion weight is a channel attention built from mean values only.

    Both inputs are globally average-pooled and concatenated, hence ``2 * y_ch``
    input channels to ``conv_xy_atten``. The sigmoid of ``conv_xy_atten`` gives
    a weight ``a`` used as ``x * a + y * (1 - a)`` before ``conv_out``.

    Args:
        x_ch (int): Channels of ``x``, the low-level feature.
        y_ch (int): Channels of ``y``, the high-level feature.
        out_ch (int): Channels of the output tensor.
        ksize (int, optional): Kernel size of the conv applied to ``x``. Default: 3.
        resize_mode (str, optional): ``mode`` passed to ``F.interpolate`` when
            resizing ``y``. Default: 'bilinear'.
    """

    def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'):
        super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode)
        self.conv_xy_atten = nn.Sequential(
            ConvBNAct(2 * y_ch, y_ch // 2, kernel_size=1),
            ConvBN(y_ch // 2, y_ch, kernel_size=1))

    @staticmethod
    def avg_reduce_hw(x):
        """Pool each tensor to 1x1 by mean and concatenate on dim 1.

        A single tensor (or one-element list/tuple) is pooled and returned
        without concatenation.
        """
        if not isinstance(x, (list, tuple)):
            return custom_adaptive_avg_pool2d(x, 1)
        pooled = [custom_adaptive_avg_pool2d(xi, 1) for xi in x]
        if len(pooled) == 1:
            return pooled[0]
        # torch.cat (not the torch.concat alias, which needs PyTorch >= 1.10)
        # for consistency with the rest of this module.
        return torch.cat(pooled, dim=1)

    def fuse(self, x, y):
        """
        Args:
            x (Tensor): The low-level feature.
            y (Tensor): The high-level feature.
        """
        atten = self.avg_reduce_hw([x, y])
        atten = torch.sigmoid(self.conv_xy_atten(atten))
        out = x * atten + y * (1 - atten)
        return self.conv_out(out)
class UAFM_SpAtten(UAFM):
    """UAFM whose fusion weight is a spatial attention built from mean and max.

    Each input is reduced over the channel dimension by mean and by max. The
    four single-channel maps are concatenated (hence 4 input channels to
    ``conv_xy_atten``). The sigmoid of ``conv_xy_atten`` gives a one-channel
    weight ``a`` used as ``x * a + y * (1 - a)`` before ``conv_out``.

    Args:
        x_ch (int): Channels of ``x``, the low-level feature.
        y_ch (int): Channels of ``y``, the high-level feature.
        out_ch (int): Channels of the output tensor.
        ksize (int, optional): Kernel size of the conv applied to ``x``. Default: 3.
        resize_mode (str, optional): ``mode`` passed to ``F.interpolate`` when
            resizing ``y``. Default: 'bilinear'.
    """

    def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'):
        super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode)
        self.conv_xy_atten = nn.Sequential(
            ConvBNAct(4, 2, kernel_size=3),
            ConvBN(2, 1, kernel_size=3))

    @staticmethod
    def avg_max_reduce_channel(x):
        """Reduce the channel dim of each tensor by mean and max, then concatenate.

        Returns ``cat([avg_ch_0, max_ch_0, avg_ch_1, max_ch_1, ...], dim=1)``.
        """
        tensors = x if isinstance(x, (list, tuple)) else [x]
        maps = []
        for t in tensors:
            maps.append(torch.mean(t, dim=1, keepdim=True))
            maps.append(torch.max(t, dim=1, keepdim=True).values)
        return torch.cat(maps, dim=1)

    def fuse(self, x, y):
        """
        Args:
            x (Tensor): The low-level feature.
            y (Tensor): The high-level feature.
        """
        stats = self.avg_max_reduce_channel([x, y])
        weight = torch.sigmoid(self.conv_xy_atten(stats))
        blended = x * weight + y * (1 - weight)
        return self.conv_out(blended)
class UAFM_SpAtten_S(UAFM):
    """UAFM whose fusion weight is a spatial attention built from mean values only.

    Each input is averaged over the channel dimension. The two single-channel
    maps are concatenated (hence 2 input channels to ``conv_xy_atten``). The
    sigmoid of ``conv_xy_atten`` gives a one-channel weight ``a`` used as
    ``x * a + y * (1 - a)`` before ``conv_out``.

    Args:
        x_ch (int): Channels of ``x``, the low-level feature.
        y_ch (int): Channels of ``y``, the high-level feature.
        out_ch (int): Channels of the output tensor.
        ksize (int, optional): Kernel size of the conv applied to ``x``. Default: 3.
        resize_mode (str, optional): ``mode`` passed to ``F.interpolate`` when
            resizing ``y``. Default: 'bilinear'.
    """

    def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'):
        super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode)
        self.conv_xy_atten = nn.Sequential(
            ConvBNAct(2, 2, kernel_size=3),
            ConvBN(2, 1, kernel_size=3))

    @staticmethod
    def avg_reduce_channel(x):
        """Average each tensor over dim 1 (keepdim) and concatenate on dim 1.

        A single tensor (or one-element list/tuple) is reduced and returned
        without concatenation.
        """
        if not isinstance(x, (list, tuple)):
            return torch.mean(x, dim=1, keepdim=True)
        means = [torch.mean(xi, dim=1, keepdim=True) for xi in x]
        if len(means) == 1:
            return means[0]
        # torch.cat (not the torch.concat alias, which needs PyTorch >= 1.10)
        # for consistency with the rest of this module.
        return torch.cat(means, dim=1)

    def fuse(self, x, y):
        """
        Args:
            x (Tensor): The low-level feature.
            y (Tensor): The high-level feature.

        Returns:
            Tensor: ``conv_out`` applied to the attention-weighted blend.
        """
        atten = self.avg_reduce_channel([x, y])
        atten = torch.sigmoid(self.conv_xy_atten(atten))
        out = x * atten + y * (1 - atten)
        # Bug fix: this previously ended in a bare `return`, so fuse() (and
        # therefore forward()) returned None.
        return self.conv_out(out)
class PPContextModule(nn.Module):
    """Pyramid-pooling context module used by PP-LiteSeg.

    Each stage applies ``CustomAdaptiveAvgPool2d(size)`` followed by a 1x1
    ``ConvBNAct``. Every stage output is bilinearly resized back to the input's
    spatial size. The resized outputs are summed and passed through a 3x3
    ``ConvBNAct``.

    Args:
        in_channels (int): Channels of the input feature.
        inter_channels (int): Channels produced by every pooling stage.
        out_channels (int): Channels of the output feature.
        bin_sizes (tuple): Pooling output size of each stage; must be non-empty.
        align_corners (bool, optional): ``align_corners`` passed to
            ``F.interpolate``. Default: False.

    Raises:
        ValueError: If ``bin_sizes`` is empty.
    """

    def __init__(self,
                 in_channels: int,
                 inter_channels: int,
                 out_channels: int,
                 bin_sizes: tuple,
                 align_corners: bool = False):
        super().__init__()
        bin_sizes = tuple(bin_sizes)
        if not bin_sizes:
            # With no stages, forward() would hand None to conv_out and fail
            # with an obscure error; reject the config up front instead.
            raise ValueError('PPContextModule requires at least one bin size')
        self.stages = nn.ModuleList([
            self._make_stage(in_channels, inter_channels, size)
            for size in bin_sizes
        ])
        self.conv_out = ConvBNAct(inter_channels, out_channels, kernel_size=3)
        self.align_corners = align_corners

    def _make_stage(self, in_channels, out_channels, size):
        """Build one pooling stage: adaptive avg-pool to ``size``, then a 1x1 conv."""
        prior = CustomAdaptiveAvgPool2d(size)
        conv = ConvBNAct(in_channels, out_channels, kernel_size=1)
        return nn.Sequential(prior, conv)

    def forward(self, input):
        """Sum the resized outputs of all stages and apply ``conv_out``."""
        input_shape = input.shape[-2:]
        out = None
        for stage in self.stages:
            x = F.interpolate(
                stage(input),
                input_shape,
                mode='bilinear',
                align_corners=self.align_corners)
            # Out-of-place accumulation; the first stage output is not mutated.
            out = x if out is None else out + x
        return self.conv_out(out)
|