# D:/workplace/python
# -*- coding: utf-8 -*-
# @File :mobilenet_base.py
# @Author:Guido LuXiaohao
# @Date :2020/4/8
# @Software:PyCharm
from keras.layers import Conv2D, DepthwiseConv2D, Dense, GlobalAveragePooling2D
from keras.layers import Activation, BatchNormalization, Add, Multiply, Reshape
from keras import backend as K


class MobileNetBase:
    def __init__(self, shape, n_class, alpha=1.0):
        """Init

        # Arguments
            shape: Tuple/list of 3 integers, shape of the input tensor.
            n_class: Integer, number of classes.
            alpha: Float, width multiplier.
        """
        self.shape = shape
        self.n_class = n_class
        self.alpha = alpha

    def _relu6(self, x):
        """ReLU6 activation: min(max(x, 0), 6)."""
        return K.relu(x, max_value=6.0)

    def _hard_swish(self, x):
        """Hard swish activation: x * ReLU6(x + 3) / 6."""
        return x * K.relu(x + 3.0, max_value=6.0) / 6.0

    def _return_activation(self, x, nl):
        """Activation selection.

        This function applies the chosen nonlinearity.

        # Arguments
            x: Tensor, input tensor.
            nl: String or None, nonlinearity activation type: 'HS' (hard swish),
                'RE' (ReLU6) or None (standard ReLU).

        # Returns
            Output tensor.
        """
        if nl == 'HS':
            x = Activation(self._hard_swish)(x)
        elif nl == 'RE':
            x = Activation(self._relu6)(x)
        elif nl is None:
            x = Activation('relu')(x)
        return x

    def _conv_block(self, inputs, filters, kernel, strides, nl):
        """Convolution Block

        This function defines a 2D convolution operation with BN and activation.

        # Arguments
            inputs: Tensor, input tensor of conv layer.
            filters: Integer, the dimensionality of the output space.
            kernel: An integer or tuple/list of 2 integers, specifying the
                width and height of the 2D convolution window.
            strides: An integer or tuple/list of 2 integers, specifying the
                strides of the convolution along the width and height. Can be
                a single integer to specify the same value for all spatial
                dimensions.
            nl: String, nonlinearity activation type.

        # Returns
            Output tensor.
        """
        channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

        x = Conv2D(filters, kernel, padding='same', strides=strides)(inputs)
        x = BatchNormalization(axis=channel_axis)(x)

        return self._return_activation(x, nl)

    def _squeeze(self, inputs):
        """Squeeze and Excitation.

        This function defines a squeeze-and-excitation structure
        (assumes channels-last data format).

        # Arguments
            inputs: Tensor, input tensor of conv layer.

        # Returns
            Output tensor, the input rescaled by the learned channel weights.
        """
        input_channels = int(inputs.shape[-1])

        # Squeeze: global spatial average followed by two fully connected layers.
        x = GlobalAveragePooling2D()(inputs)
        x = Dense(input_channels, activation='relu')(x)
        x = Dense(input_channels, activation='hard_sigmoid')(x)
        # Excite: reshape to (1, 1, C) and rescale the input channel-wise.
        x = Reshape((1, 1, input_channels))(x)
        x = Multiply()([inputs, x])

        return x

    def _bottleneck(self, inputs, filters, kernel, e, s, squeeze, nl):
        """Bottleneck

        This function defines a basic bottleneck structure.

        # Arguments
            inputs: Tensor, input tensor of conv layer.
            filters: Integer, the dimensionality of the output space.
            kernel: An integer or tuple/list of 2 integers, specifying the
                width and height of the 2D convolution window.
            e: Integer, expansion size, i.e. the number of channels in the
                expanded (pointwise) layer.
            s: Integer, stride of the depthwise convolution along the width
                and height.
            squeeze: Boolean, whether to use the squeeze-and-excitation block.
            nl: String, nonlinearity activation type.

        # Returns
            Output tensor.
        """
        channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
        input_shape = K.int_shape(inputs)

        tchannel = int(e)
        cchannel = int(self.alpha * filters)

        # Use a residual connection only when the block keeps both the spatial
        # resolution (stride 1) and the channel count of its input.
        r = s == 1 and input_shape[channel_axis] == cchannel

        # Expand with a 1x1 convolution, then filter with a depthwise convolution.
        x = self._conv_block(inputs, tchannel, (1, 1), (1, 1), nl)

        x = DepthwiseConv2D(kernel, strides=(s, s), depth_multiplier=1, padding='same')(x)
        x = BatchNormalization(axis=channel_axis)(x)
        x = self._return_activation(x, nl)

        if squeeze:
            x = self._squeeze(x)

        # Project back down to cchannel with a linear 1x1 convolution.
        x = Conv2D(cchannel, (1, 1), strides=(1, 1), padding='same')(x)
        x = BatchNormalization(axis=channel_axis)(x)

        if r:
            x = Add()([x, inputs])

        return x

    def build(self):
        """Build the network. Implemented by subclasses."""
        pass
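

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): one possible way a subclass
# could compose the helpers above into a small network. The class name
# `ExampleMobileNet`, the layer configuration and the `Input`/`Model` imports
# below are illustrative assumptions, not the author's actual architecture.
# ---------------------------------------------------------------------------
from keras.layers import Input
from keras.models import Model


class ExampleMobileNet(MobileNetBase):
    def build(self):
        inputs = Input(shape=self.shape)

        # Stem: full 3x3 convolution with stride 2 and hard swish.
        x = self._conv_block(inputs, 16, (3, 3), strides=(2, 2), nl='HS')

        # Two bottlenecks: the first downsamples, the second keeps the shape
        # and therefore gets a residual connection inside `_bottleneck`.
        x = self._bottleneck(x, 24, (3, 3), e=72, s=2, squeeze=True, nl='RE')
        x = self._bottleneck(x, 24, (3, 3), e=88, s=1, squeeze=True, nl='RE')

        # Classification head.
        x = GlobalAveragePooling2D()(x)
        outputs = Dense(self.n_class, activation='softmax')(x)

        return Model(inputs, outputs)


# Example: model = ExampleMobileNet(shape=(224, 224, 3), n_class=1000).build()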