resnext.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. # D:/workplace/python
  2. # -*- coding: utf-8 -*-
  3. # @File :resnext.py
  4. # @Author:Guido LuXiaohao
  5. # @Date :2020/4/8
  6. # @Software:PyCharm
  7. '''ResNeXt models for Keras.
  8. This is a revised implementation from Somshubra Majumdar's SENet repo:
  9. (https://github.com/titu1994/keras-squeeze-excite-network)
  10. # Reference
  11. - [Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/pdf/1611.05431.pdf)
  12. '''
  13. from __future__ import print_function
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. import warnings
  17. from keras.models import Model
  18. from keras.layers.core import Dense, Lambda
  19. from keras.layers.advanced_activations import LeakyReLU
  20. from keras.layers.convolutional import Conv2D
  21. from keras.layers.pooling import GlobalAveragePooling2D, GlobalMaxPooling2D, MaxPooling2D
  22. from keras.layers import Input
  23. from keras.layers.merge import concatenate, add
  24. from keras.layers.normalization import BatchNormalization
  25. from keras.regularizers import l2
  26. from keras.utils.layer_utils import convert_all_kernels_in_model
  27. from keras.utils.data_utils import get_file
  28. from keras.engine.topology import get_source_inputs
  29. from keras_applications.imagenet_utils import _obtain_input_shape
  30. import keras.backend as K
  31. from nets.attention_module.se_cbam import attach_attention_module
  32. CIFAR_TH_WEIGHTS_PATH = ''
  33. CIFAR_TF_WEIGHTS_PATH = ''
  34. CIFAR_TH_WEIGHTS_PATH_NO_TOP = ''
  35. CIFAR_TF_WEIGHTS_PATH_NO_TOP = ''
  36. IMAGENET_TH_WEIGHTS_PATH = ''
  37. IMAGENET_TF_WEIGHTS_PATH = ''
  38. IMAGENET_TH_WEIGHTS_PATH_NO_TOP = ''
  39. IMAGENET_TF_WEIGHTS_PATH_NO_TOP = ''
  40. def ResNext(input_shape=None,
  41. depth=29,
  42. cardinality=8,
  43. width=64,
  44. weight_decay=5e-4,
  45. include_top=True,
  46. weights=None,
  47. input_tensor=None,
  48. pooling=None,
  49. classes=10,
  50. attention_module=None):
  51. """Instantiate the ResNeXt architecture. Note that ,
  52. when using TensorFlow for best performance you should set
  53. `image_data_format="channels_last"` in your Keras config
  54. at ~/.keras/keras.json.
  55. The model are compatible with both
  56. TensorFlow and Theano. The dimension ordering
  57. convention used by the model is the one
  58. specified in your Keras config file.
  59. # Arguments
  60. depth: number or layers in the ResNeXt model. Can be an
  61. integer or a list of integers.
  62. cardinality: the size of the set of transformations
  63. width: multiplier to the ResNeXt width (number of filters)
  64. weight_decay: weight decay (l2 norm)
  65. include_top: whether to include the fully-connected
  66. layer at the top of the network.
  67. weights: `None` (random initialization)
  68. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
  69. to use as image input for the model.
  70. input_shape: optional shape tuple, only to be specified
  71. if `include_top` is False (otherwise the input shape
  72. has to be `(32, 32, 3)` (with `tf` dim ordering)
  73. or `(3, 32, 32)` (with `th` dim ordering).
  74. It should have exactly 3 inputs channels,
  75. and width and height should be no smaller than 8.
  76. E.g. `(200, 200, 3)` would be one valid value.
  77. pooling: Optional pooling mode for feature extraction
  78. when `include_top` is `False`.
  79. - `None` means that the output of the model will be
  80. the 4D tensor output of the
  81. last convolutional layer.
  82. - `avg` means that global average pooling
  83. will be applied to the output of the
  84. last convolutional layer, and thus
  85. the output of the model will be a 2D tensor.
  86. - `max` means that global max pooling will
  87. be applied.
  88. classes: optional number of classes to classify images
  89. into, only to be specified if `include_top` is True, and
  90. if no `weights` argument is specified.
  91. # Returns
  92. A Keras model instance.
  93. """
  94. if weights not in {'cifar10', None}:
  95. raise ValueError('The `weights` argument should be either '
  96. '`None` (random initialization) or `cifar10` '
  97. '(pre-training on CIFAR-10).')
  98. if weights == 'cifar10' and include_top and classes != 10:
  99. raise ValueError('If using `weights` as CIFAR 10 with `include_top`'
  100. ' as true, `classes` should be 10')
  101. if type(depth) == int:
  102. if (depth - 2) % 9 != 0:
  103. raise ValueError('Depth of the network must be such that (depth - 2)'
  104. 'should be divisible by 9.')
  105. # Determine proper input shape
  106. input_shape = _obtain_input_shape(input_shape,
  107. default_size=32,
  108. min_size=8,
  109. data_format=K.image_data_format(),
  110. require_flatten=include_top)
  111. if input_tensor is None:
  112. img_input = Input(shape=input_shape)
  113. else:
  114. if not K.is_keras_tensor(input_tensor):
  115. img_input = Input(tensor=input_tensor, shape=input_shape)
  116. else:
  117. img_input = input_tensor
  118. x = __create_res_next(classes, img_input, include_top, depth, cardinality, width,
  119. weight_decay, pooling, attention_module)
  120. # Ensure that the model takes into account
  121. # any potential predecessors of `input_tensor`.
  122. if input_tensor is not None:
  123. inputs = get_source_inputs(input_tensor)
  124. else:
  125. inputs = img_input
  126. # Create model.
  127. model = Model(inputs, x, name='resnext')
  128. return model
  129. def ResNextImageNet(input_shape=None,
  130. depth=[3, 4, 6, 3],
  131. cardinality=32,
  132. width=4,
  133. weight_decay=5e-4,
  134. include_top=True,
  135. weights=None,
  136. input_tensor=None,
  137. pooling=None,
  138. classes=1000,
  139. attention_module=None):
  140. """ Instantiate the ResNeXt architecture for the ImageNet dataset. Note that ,
  141. when using TensorFlow for best performance you should set
  142. `image_data_format="channels_last"` in your Keras config
  143. at ~/.keras/keras.json.
  144. The model are compatible with both
  145. TensorFlow and Theano. The dimension ordering
  146. convention used by the model is the one
  147. specified in your Keras config file.
  148. # Arguments
  149. depth: number or layers in the each block, defined as a list.
  150. ResNeXt-50 can be defined as [3, 4, 6, 3].
  151. ResNeXt-101 can be defined as [3, 4, 23, 3].
  152. Defaults is ResNeXt-50.
  153. cardinality: the size of the set of transformations
  154. width: multiplier to the ResNeXt width (number of filters)
  155. weight_decay: weight decay (l2 norm)
  156. include_top: whether to include the fully-connected
  157. layer at the top of the network.
  158. weights: `None` (random initialization) or `imagenet` (trained
  159. on ImageNet)
  160. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
  161. to use as image input for the model.
  162. input_shape: optional shape tuple, only to be specified
  163. if `include_top` is False (otherwise the input shape
  164. has to be `(224, 224, 3)` (with `tf` dim ordering)
  165. or `(3, 224, 224)` (with `th` dim ordering).
  166. It should have exactly 3 inputs channels,
  167. and width and height should be no smaller than 8.
  168. E.g. `(200, 200, 3)` would be one valid value.
  169. pooling: Optional pooling mode for feature extraction
  170. when `include_top` is `False`.
  171. - `None` means that the output of the model will be
  172. the 4D tensor output of the
  173. last convolutional layer.
  174. - `avg` means that global average pooling
  175. will be applied to the output of the
  176. last convolutional layer, and thus
  177. the output of the model will be a 2D tensor.
  178. - `max` means that global max pooling will
  179. be applied.
  180. classes: optional number of classes to classify images
  181. into, only to be specified if `include_top` is True, and
  182. if no `weights` argument is specified.
  183. # Returns
  184. A Keras model instance.
  185. """
  186. if weights not in {'imagenet', None}:
  187. raise ValueError('The `weights` argument should be either '
  188. '`None` (random initialization) or `imagenet` '
  189. '(pre-training on ImageNet).')
  190. if weights == 'imagenet' and include_top and classes != 1000:
  191. raise ValueError('If using `weights` as imagenet with `include_top`'
  192. ' as true, `classes` should be 1000')
  193. if type(depth) == int and (depth - 2) % 9 != 0:
  194. raise ValueError('Depth of the network must be such that (depth - 2)'
  195. 'should be divisible by 9.')
  196. # Determine proper input shape
  197. input_shape = _obtain_input_shape(input_shape,
  198. default_size=224,
  199. min_size=112,
  200. data_format=K.image_data_format(),
  201. require_flatten=include_top)
  202. if input_tensor is None:
  203. img_input = Input(shape=input_shape)
  204. else:
  205. if not K.is_keras_tensor(input_tensor):
  206. img_input = Input(tensor=input_tensor, shape=input_shape)
  207. else:
  208. img_input = input_tensor
  209. x = __create_res_next_imagenet(classes, img_input, include_top, depth, cardinality, width,
  210. weight_decay, pooling)
  211. # Ensure that the model takes into account
  212. # any potential predecessors of `input_tensor`.
  213. if input_tensor is not None:
  214. inputs = get_source_inputs(input_tensor)
  215. else:
  216. inputs = img_input
  217. # Create model.
  218. model = Model(inputs, x, name='resnext')
  219. return model
  220. def __initial_conv_block(input, weight_decay=5e-4):
  221. ''' Adds an initial convolution block, with batch normalization and relu activation
  222. Args:
  223. input: input tensor
  224. weight_decay: weight decay factor
  225. Returns: a keras tensor
  226. '''
  227. channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
  228. x = Conv2D(64, (3, 3), padding='same', use_bias=False, kernel_initializer='he_normal',
  229. kernel_regularizer=l2(weight_decay))(input)
  230. x = BatchNormalization(axis=channel_axis)(x)
  231. x = LeakyReLU()(x)
  232. return x
  233. def __initial_conv_block_inception(input, weight_decay=5e-4):
  234. ''' Adds an initial conv block, with batch norm and relu for the inception resnext
  235. Args:
  236. input: input tensor
  237. weight_decay: weight decay factor
  238. Returns: a keras tensor
  239. '''
  240. channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
  241. x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal',
  242. kernel_regularizer=l2(weight_decay), strides=(2, 2))(input)
  243. x = BatchNormalization(axis=channel_axis)(x)
  244. x = LeakyReLU()(x)
  245. x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
  246. return x
  247. def __grouped_convolution_block(input, grouped_channels, cardinality, strides, weight_decay=5e-4):
  248. ''' Adds a grouped convolution block. It is an equivalent block from the paper
  249. Args:
  250. input: input tensor
  251. grouped_channels: grouped number of filters
  252. cardinality: cardinality factor describing the number of groups
  253. strides: performs strided convolution for downscaling if > 1
  254. weight_decay: weight decay term
  255. Returns: a keras tensor
  256. '''
  257. init = input
  258. channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
  259. group_list = []
  260. if cardinality == 1:
  261. # with cardinality 1, it is a standard convolution
  262. x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides),
  263. kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(init)
  264. x = BatchNormalization(axis=channel_axis)(x)
  265. x = LeakyReLU()(x)
  266. return x
  267. for c in range(cardinality):
  268. x = Lambda(lambda z: z[:, :, :, c * grouped_channels:(c + 1) * grouped_channels]
  269. if K.image_data_format() == 'channels_last' else
  270. lambda z: z[:, c * grouped_channels:(c + 1) * grouped_channels, :, :])(input)
  271. x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides),
  272. kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(x)
  273. group_list.append(x)
  274. group_merge = concatenate(group_list, axis=channel_axis)
  275. x = BatchNormalization(axis=channel_axis)(group_merge)
  276. x = LeakyReLU()(x)
  277. return x
  278. def __bottleneck_block(input, filters=64, cardinality=8, strides=1, weight_decay=5e-4, attention_module=None):
  279. ''' Adds a bottleneck block
  280. Args:
  281. input: input tensor
  282. filters: number of output filters
  283. cardinality: cardinality factor described number of
  284. grouped convolutions
  285. strides: performs strided convolution for downsampling if > 1
  286. weight_decay: weight decay factor
  287. Returns: a keras tensor
  288. '''
  289. init = input
  290. grouped_channels = int(filters / cardinality)
  291. channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
  292. # Check if input number of filters is same as 16 * k, else create convolution2d for this input
  293. if K.image_data_format() == 'channels_first':
  294. if init._keras_shape[1] != 2 * filters:
  295. init = Conv2D(filters * 2, (1, 1), padding='same', strides=(strides, strides),
  296. use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(init)
  297. init = BatchNormalization(axis=channel_axis)(init)
  298. else:
  299. if init._keras_shape[-1] != 2 * filters:
  300. init = Conv2D(filters * 2, (1, 1), padding='same', strides=(strides, strides),
  301. use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(init)
  302. init = BatchNormalization(axis=channel_axis)(init)
  303. x = Conv2D(filters, (1, 1), padding='same', use_bias=False,
  304. kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(input)
  305. x = BatchNormalization(axis=channel_axis)(x)
  306. x = LeakyReLU()(x)
  307. x = __grouped_convolution_block(x, grouped_channels, cardinality, strides, weight_decay)
  308. x = Conv2D(filters * 2, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal',
  309. kernel_regularizer=l2(weight_decay))(x)
  310. x = BatchNormalization(axis=channel_axis)(x)
  311. # attention_module
  312. if attention_module is not None:
  313. x = attach_attention_module(x, attention_module)
  314. x = add([init, x])
  315. x = LeakyReLU()(x)
  316. return x
  317. def __create_res_next(nb_classes, img_input, include_top, depth=29, cardinality=8, width=4,
  318. weight_decay=5e-4, pooling=None, attention_module=None):
  319. ''' Creates a ResNeXt model with specified parameters
  320. Args:
  321. nb_classes: Number of output classes
  322. img_input: Input tensor or layer
  323. include_top: Flag to include the last dense layer
  324. depth: Depth of the network. Can be an positive integer or a list
  325. Compute N = (n - 2) / 9.
  326. For a depth of 56, n = 56, N = (56 - 2) / 9 = 6
  327. For a depth of 101, n = 101, N = (101 - 2) / 9 = 11
  328. cardinality: the size of the set of transformations.
  329. Increasing cardinality improves classification accuracy,
  330. width: Width of the network.
  331. weight_decay: weight_decay (l2 norm)
  332. pooling: Optional pooling mode for feature extraction
  333. when `include_top` is `False`.
  334. - `None` means that the output of the model will be
  335. the 4D tensor output of the
  336. last convolutional layer.
  337. - `avg` means that global average pooling
  338. will be applied to the output of the
  339. last convolutional layer, and thus
  340. the output of the model will be a 2D tensor.
  341. - `max` means that global max pooling will
  342. be applied.
  343. Returns: a Keras Model
  344. '''
  345. if type(depth) is list or type(depth) is tuple:
  346. # If a list is provided, defer to user how many blocks are present
  347. N = list(depth)
  348. else:
  349. # Otherwise, default to 3 blocks each of default number of group convolution blocks
  350. N = [(depth - 2) // 9 for _ in range(3)]
  351. filters = cardinality * width
  352. filters_list = []
  353. for i in range(len(N)):
  354. filters_list.append(filters)
  355. filters *= 2 # double the size of the filters
  356. x = __initial_conv_block(img_input, weight_decay)
  357. # block 1 (no pooling)
  358. for i in range(N[0]):
  359. x = __bottleneck_block(x, filters_list[0], cardinality, strides=1,
  360. weight_decay=weight_decay, attention_module=attention_module)
  361. N = N[1:] # remove the first block from block definition list
  362. filters_list = filters_list[1:] # remove the first filter from the filter list
  363. # block 2 to N
  364. for block_idx, n_i in enumerate(N):
  365. for i in range(n_i):
  366. if i == 0:
  367. x = __bottleneck_block(x, filters_list[block_idx], cardinality, strides=2,
  368. weight_decay=weight_decay, attention_module=attention_module)
  369. else:
  370. x = __bottleneck_block(x, filters_list[block_idx], cardinality, strides=1,
  371. weight_decay=weight_decay, attention_module=attention_module)
  372. if include_top:
  373. x = GlobalAveragePooling2D()(x)
  374. x = Dense(nb_classes, use_bias=False, kernel_regularizer=l2(weight_decay),
  375. kernel_initializer='he_normal', activation='softmax')(x)
  376. else:
  377. if pooling == 'avg':
  378. x = GlobalAveragePooling2D()(x)
  379. elif pooling == 'max':
  380. x = GlobalMaxPooling2D()(x)
  381. return x
  382. def __create_res_next_imagenet(nb_classes, img_input, include_top, depth, cardinality=32, width=4,
  383. weight_decay=5e-4, pooling=None, attention_module=None):
  384. ''' Creates a ResNeXt model with specified parameters
  385. Args:
  386. nb_classes: Number of output classes
  387. img_input: Input tensor or layer
  388. include_top: Flag to include the last dense layer
  389. depth: Depth of the network. List of integers.
  390. Increasing cardinality improves classification accuracy,
  391. width: Width of the network.
  392. weight_decay: weight_decay (l2 norm)
  393. pooling: Optional pooling mode for feature extraction
  394. when `include_top` is `False`.
  395. - `None` means that the output of the model will be
  396. the 4D tensor output of the
  397. last convolutional layer.
  398. - `avg` means that global average pooling
  399. will be applied to the output of the
  400. last convolutional layer, and thus
  401. the output of the model will be a 2D tensor.
  402. - `max` means that global max pooling will
  403. be applied.
  404. Returns: a Keras Model
  405. '''
  406. if type(depth) is list or type(depth) is tuple:
  407. # If a list is provided, defer to user how many blocks are present
  408. N = list(depth)
  409. else:
  410. # Otherwise, default to 3 blocks each of default number of group convolution blocks
  411. N = [(depth - 2) // 9 for _ in range(3)]
  412. filters = cardinality * width
  413. filters_list = []
  414. for i in range(len(N)):
  415. filters_list.append(filters)
  416. filters *= 2 # double the size of the filters
  417. x = __initial_conv_block_inception(img_input, weight_decay)
  418. # block 1 (no pooling)
  419. for i in range(N[0]):
  420. x = __bottleneck_block(x, filters_list[0], cardinality, strides=1,
  421. weight_decay=weight_decay, attention_module=attention_module)
  422. N = N[1:] # remove the first block from block definition list
  423. filters_list = filters_list[1:] # remove the first filter from the filter list
  424. # block 2 to N
  425. for block_idx, n_i in enumerate(N):
  426. for i in range(n_i):
  427. if i == 0:
  428. x = __bottleneck_block(x, filters_list[block_idx], cardinality, strides=2,
  429. weight_decay=weight_decay, attention_module=attention_module)
  430. else:
  431. x = __bottleneck_block(x, filters_list[block_idx], cardinality, strides=1,
  432. weight_decay=weight_decay, attention_module=attention_module)
  433. if include_top:
  434. x = GlobalAveragePooling2D()(x)
  435. x = Dense(nb_classes, use_bias=False, kernel_regularizer=l2(weight_decay),
  436. kernel_initializer='he_normal', activation='softmax')(x)
  437. else:
  438. if pooling == 'avg':
  439. x = GlobalAveragePooling2D()(x)
  440. elif pooling == 'max':
  441. x = GlobalMaxPooling2D()(x)
  442. return x