temp_mobilenetv3.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # D:/workplace/python
  2. # -*- coding: utf-8 -*-
  3. # @File :temp_mobilenetv3.py
  4. # @Author:Guido LuXiaohao
  5. # @Date :2020/4/9
  6. # @Software:PyCharm
  7. from keras.models import Model
  8. from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Reshape, Dropout, Dense, MaxPooling2D, concatenate,AveragePooling2D
  9. from keras.utils.vis_utils import plot_model
  10. from nets.mobilenet.mobilenet_base import MobileNetBase
  11. from keras import backend as K
  12. from keras.engine.topology import Layer
  13. from keras.layers import activations, initializers, regularizers, constraints, Lambda
  14. from keras.engine import InputSpec
  15. import tensorflow as tf
  16. class AMSoftmax(Layer):
  17. def __init__(self, units, s, m,
  18. kernel_initializer='glorot_uniform',
  19. kernel_regularizer=None,
  20. kernel_constraint=None,
  21. **kwargs
  22. ):
  23. if 'input_shape' not in kwargs and 'input_dim' in kwargs:
  24. kwargs['input_shape'] = (kwargs.pop('input_dim'),)
  25. super(AMSoftmax, self).__init__(**kwargs)
  26. self.units = units
  27. self.s = s
  28. self.m = m
  29. self.kernel_initializer = initializers.get(kernel_initializer)
  30. self.kernel_regularizer = regularizers.get(kernel_regularizer)
  31. self.kernel_constraint = constraints.get(kernel_constraint)
  32. self.input_spec = InputSpec(min_ndim=2)
  33. self.supports_masking = True
  34. def build(self, input_shape):
  35. assert len(input_shape) >= 2
  36. input_dim = input_shape[-1]
  37. self.kernel = self.add_weight(shape=(input_dim, self.units),
  38. initializer=self.kernel_initializer,
  39. name='kernel',
  40. regularizer=self.kernel_regularizer,
  41. constraint=self.kernel_constraint)
  42. self.bias = None
  43. self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
  44. self.built = True
  45. def call(self, inputs, **kwargs):
  46. inputs = tf.nn.l2_normalize(inputs, dim=-1)
  47. self.kernel = tf.nn.l2_normalize(self.kernel, dim=(0, 1)) # W归一化
  48. dis_cosin = K.dot(inputs, self.kernel)
  49. psi = dis_cosin - self.m
  50. e_costheta = K.exp(self.s * dis_cosin)
  51. e_psi = K.exp(self.s * psi)
  52. sum_x = K.sum(e_costheta, axis=-1, keepdims=True)
  53. temp = e_psi - e_costheta
  54. temp = temp + sum_x
  55. output = e_psi / temp
  56. return output
  57. def amsoftmax_loss(y_true, y_pred):
  58. d1 = K.sum(y_true * y_pred, axis=-1)
  59. d1 = K.log(K.clip(d1, K.epsilon(), None))
  60. loss = -K.mean(d1, axis=-1)
  61. return loss
  62. class MobileNetV3_Small(MobileNetBase):
  63. def __init__(self, shape, n_class, alpha=1.0, include_top=True):
  64. """Init.
  65. # Arguments
  66. input_shape: An integer or tuple/list of 3 integers, shape
  67. of input tensor.
  68. n_class: Integer, number of classes.
  69. alpha: Integer, width multiplier.
  70. include_top: if inculde classification layer.
  71. # Returns
  72. MobileNetv3 model.
  73. """
  74. super(MobileNetV3_Small, self).__init__(shape, n_class, alpha)
  75. self.include_top = include_top
  76. def build(self, plot=False):
  77. """build MobileNetV3 Small.
  78. # Arguments
  79. plot: Boolean, weather to plot model.
  80. # Returns
  81. model: Model, model.
  82. """
  83. inputs = Input(shape=self.shape)
  84. x = self._conv_block(inputs, 16, (3, 3), strides=(2, 2), nl=None)
  85. x = self._bottleneck(x, 16, (3, 3), e=16, s=2, squeeze=True, nl=None)
  86. x = self._bottleneck(x, 24, (3, 3), e=72, s=2, squeeze=False, nl=None)
  87. x = self._bottleneck(x, 24, (3, 3), e=88, s=1, squeeze=False, nl=None)
  88. x = self._bottleneck(x, 40, (5, 5), e=96, s=2, squeeze=True, nl=None)
  89. x = self._bottleneck(x, 40, (5, 5), e=240, s=1, squeeze=True, nl=None)
  90. x = self._bottleneck(x, 40, (5, 5), e=240, s=1, squeeze=True, nl=None)
  91. x = self._bottleneck(x, 48, (5, 5), e=120, s=1, squeeze=True, nl=None)
  92. x = self._bottleneck(x, 48, (5, 5), e=144, s=1, squeeze=True, nl=None)
  93. x = self._bottleneck(x, 96, (5, 5), e=288, s=2, squeeze=True, nl=None)
  94. x = self._bottleneck(x, 96, (5, 5), e=576, s=1, squeeze=True, nl=None)
  95. x = self._bottleneck(x, 96, (5, 5), e=576, s=1, squeeze=True, nl=None)
  96. x = self._conv_block(x, 576, (1, 1), strides=(1, 1), nl=None)
  97. x = GlobalAveragePooling2D()(x)
  98. x = Reshape((1, 1, 576))(x)
  99. x = Conv2D(1280, (1, 1), padding='same')(x)
  100. x = self._return_activation(x, None)
  101. # x = Dropout(0.5)(x)
  102. if self.include_top:
  103. x = Conv2D(self.n_class, (1, 1), padding='same', activation='softmax')(x)
  104. x = Reshape((self.n_class,), name='reshape_11')(x)
  105. model = Model(inputs, x)
  106. model.summary()
  107. if plot:
  108. plot_model(model, to_file='images/MobileNetv3_small.png', show_shapes=True)
  109. return model
  110. # A = MobileNetV3_Small((256, 256, 3), 4)
  111. # model = A.build()
  112. # model.save(r"C:\Users\Administrator\Desktop\1.h5")