train.py 5.8 KB

# D:/workplace/python
# -*- coding: utf-8 -*-
# @File    : train.py
# @Author  : Guido LuXiaohao
# @Date    : 2020/4/8
# @Software: PyCharm
import numpy as np
import tensorflow as tf
from keras.callbacks import LearningRateScheduler, ModelCheckpoint, ReduceLROnPlateau
from keras.optimizers import Adam, SGD, RMSprop
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D, Dropout

from dataset.class_dataloader import load_sample, DataGenerator, myDataloderV2, MydatasetV2
from dataset.augmention import get_training_augmentation, get_valing_augmentation
from nets.temp_mobilenetv3 import MobileNetV3_Small
from nets.temp_ghost import GhostModel
from nets.temp_multinetV1 import mutinet
from nets.densenet.densenet import (DenseNet, DenseNetImageNet121, DenseNetImageNet161,
                                    DenseNetImageNet169, DenseNetImageNet201, DenseNetImageNet264)
from nets.efficientnets.efficientnet import (EfficientNetB0, EfficientNetB1, EfficientNetB2,
                                             EfficientNetB3, EfficientNetB4, EfficientNetB5,
                                             EfficientNetB6, EfficientNetB7)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
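# allow_growth makes TensorFlow 1.x allocate GPU memory on demand instead of
# reserving the whole device up front, so other processes can share the GPU.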


def lr_schedule(epoch):  # maps the epoch number to a learning rate
    """Learning Rate Schedule

    The learning rate is reduced after 15, 50, 100 and 1800 epochs.
    Called automatically every epoch as part of the training callbacks.

    # Arguments
        epoch (int): current epoch number

    # Returns
        lr (float32): learning rate
    """
    lr = 1e-3
    if epoch > 1800:    # 0.5e-6
        lr *= 0.5e-3
    elif epoch > 100:   # 1e-6
        lr *= 1e-3
    elif epoch > 50:    # 1e-4
        lr *= 1e-1
    elif epoch > 15:    # 5e-4
        lr *= 5e-1
    print('Learning rate: ', lr)
    return lr
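
# For reference, the schedule the branches above produce:
#   lr_schedule(0)    -> 1e-3
#   lr_schedule(20)   -> 5e-4
#   lr_schedule(60)   -> 1e-4
#   lr_schedule(200)  -> 1e-6
#   lr_schedule(2000) -> 5e-7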


def add_new_last_layer(base_model, nb_classes):
    """Append a global-pooling + dense classification head to a backbone."""
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    x = Dense(512, activation='relu')(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    predictions = Dense(nb_classes, activation='softmax')(x)
    # Keras 2 expects the keyword arguments `inputs`/`outputs`; the singular
    # `input`/`output` forms are deprecated.
    model = Model(inputs=base_model.input, outputs=predictions)
    return model
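
# Sketch of intended use (hypothetical; assumes the backbone is built without
# its own classification head): attach the new head, then freeze the base
# layers for transfer learning before compiling.
#   base = EfficientNetB0(input_shape=(160, 160, 3), weights=None, include_top=False)
#   model = add_new_last_layer(base, nb_classes=4)
#   for layer in base.layers:
#       layer.trainable = False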


if __name__ == '__main__':
    # Training hyperparameters
    batch_size = 2
    epochs = 3000
    num_classes = 4
    model_save_path = r""
    model_name = 'model.{epoch:04d}-{val_acc:.5f}.h5'
    train_file_path = r'C:\Users\VINNO\Desktop\1111111\train'
    test_file_path = r'C:\Users\VINNO\Desktop\1111111\val'

    # Labels are generated from the folder structure
    trainX, trainY = load_sample(train_file_path)
    testX, testY = load_sample(test_file_path)
    traindata = MydatasetV2(trainX, trainY, augmentation=get_training_augmentation(), size=(160, 160))
    valdata = MydatasetV2(testX, testY, augmentation=get_valing_augmentation(), size=(160, 160))
    # training_generator = DataGenerator(trainX, trainY, Is_train=True, shuffle=True, **params)
    # testing_generator = DataGenerator(testX, testY, Is_train=False, shuffle=False, **params)
    training_generator = myDataloderV2(traindata, n_classes=num_classes, batch_size=batch_size, shuffle=True, train=False)
    testing_generator = myDataloderV2(valdata, n_classes=num_classes, batch_size=batch_size, shuffle=False, train=False)

    A = MobileNetV3_Small((160, 160, 3), num_classes, alpha=1.0, include_top=True)
    model = A.build()
    # Alternative models:
    # model = DenseNet(input_shape=(256, 256, 3), depth=40, classes=10, activation='softmax', attention_module='se_block')
    # model = DenseNetImageNet121(input_shape=(256, 256, 3), classes=10, activation='softmax', attention_module='cbam_block')
    # model = DenseNetImageNet161(input_shape=(256, 256, 3), classes=10, activation='softmax', attention_module=None)
    # model = DenseNetImageNet169(input_shape=(256, 256, 3), classes=10, activation='softmax', attention_module=None)
    # model = DenseNetImageNet201(input_shape=(256, 256, 3), classes=10, activation='softmax', attention_module=None)
    # model = DenseNetImageNet264(input_shape=(256, 256, 3), classes=10, activation='softmax', attention_module=None)
    # model = EfficientNetB0(input_shape=(256, 256, 3), weights=None, classes=10)
    # model = EfficientNetB1(input_shape=(256, 256, 3), weights=None, classes=10)
    # model = EfficientNetB2(input_shape=(256, 256, 3), weights=None, classes=10)
    # model = EfficientNetB3(input_shape=(256, 256, 3), weights=None, classes=10)
    # model = EfficientNetB4(input_shape=(256, 256, 3), weights=None, classes=10)
    # model = EfficientNetB5(input_shape=(256, 256, 3), weights=None, classes=10)
    # model = EfficientNetB6(input_shape=(256, 256, 3), weights=None, classes=10)
    # model = EfficientNetB7(input_shape=(256, 256, 3), weights=None, classes=10)
    lr_scheduler = LearningRateScheduler(lr_schedule)
    # Reduce the learning rate when validation progress plateaus
    lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                                   cooldown=0,
                                   patience=5,
                                   min_lr=0.5e-6)
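    # Note: LearningRateScheduler re-applies lr_schedule at the start of every
    # epoch, so any reduction made by ReduceLROnPlateau is overwritten on the
    # next epoch; when both are active, the fixed schedule dominates.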
    # Use the configured optimizer (the original compiled a fresh Adam and
    # left `opt`, including its decay term, unused)
    opt = Adam(lr=lr_schedule(0), decay=lr_schedule(0) / epochs)
    model.compile(loss='categorical_crossentropy', optimizer=opt,
                  metrics=["accuracy"])
    filepath = model_save_path + model_name
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True,
                                 mode='max')
    # Callbacks: checkpointing plus the learning-rate schedule and reducer
    callbacks_list = [checkpoint, lr_reducer, lr_scheduler]
    # train the network
    print("[INFO] training network...")
    H = model.fit_generator(generator=training_generator,
                            validation_data=testing_generator,
                            steps_per_epoch=len(trainX) // batch_size,
                            epochs=epochs, callbacks=callbacks_list, verbose=1)
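    # The returned History object keeps per-epoch metrics in H.history; a
    # minimal sketch for saving them afterwards (hypothetical file name):
    #   import json
    #   with open("training_history.json", "w") as f:
    #       json.dump(H.history, f)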