train_resnet_43.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. from __future__ import print_function # Use a function definition from future version (say 3.x from 2.7 interpreter)
  2. from cntk.initializer import he_normal, normal
  3. from cntk.layers import AveragePooling, MaxPooling, BatchNormalization, Convolution, Dense
  4. from cntk.ops import element_times, relu
  5. from typing import Any, Union, Tuple
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. import os
  9. import cntk as C
  10. import cv2
  11. import time
  12. import shutil
  13. from resnet_model_12 import conv_bn, conv_bn_relu, resnet_basic, resnet_basic_inc, resnet_basic_stack, resnet_bottleneck, \
  14. resnet_bottleneck_inc, resnet_bottleneck_stack, resnet_12
  15. from collections import OrderedDict
  16. #图片及标签数据
  17. image_height = 96
  18. image_width = 96
  19. num_channels = 3
  20. num_classes = 2
  21. EPOCH_SIZE = 78789
  22. MINIBATCH_SIZE= 70
  23. INDEX=43
  24. CHECKPOINT = 'checkpoint/checkpointed.tmp'
  25. LEARNING_RATIO = C.learning_parameter_schedule([0.005] * 70 + [0.001] * 70 + [0.0001],epoch_size = EPOCH_SIZE)
  26. L2_REG=0.02
  27. is_break=False
  28. import cntk.io.transforms as xforms
  29. def create_reader(map_file, mean_file, train):
  30. """
  31. 根据txt数据准备可供model读取的数据类型
  32. :param map_file: 存放数据路径的txt
  33. :param mean_file: 所有图片的均值文件(此处没有用到)
  34. :param train: bool,是否训练
  35. :return: 得到batch_source,可通过其得到图片内部数据及标签
  36. """
  37. transforms = [] # list
  38. if train:
  39. transforms += [xforms.crop(crop_type='randomside', side_ratio=1.0)]
  40. transforms += [xforms.scale(width=image_width, height=image_height, channels=num_channels),
  41. #xforms.mean(mean_file)
  42. ]
  43. transforms += [
  44. xforms.color(brightness_radius=0.3, contrast_radius=0.0, saturation_radius=0.0)
  45. ]
  46. # deserializer
  47. features = C.io.StreamDef(field='image', transforms=transforms)
  48. labels = C.io.StreamDef(field='label', shape=num_classes)
  49. streamdefs = C.io.StreamDefs(
  50. features=features, # first column in map file is referred to as 'image'
  51. labels=labels # and second as 'label'
  52. )
  53. imagedeserialized = C.io.ImageDeserializer(map_file, streamdefs)
  54. batch_source = C.io.MinibatchSource(imagedeserialized)
  55. return batch_source
  56. # ValidateVisual(reader_validate, reader_test, trainer, minibatch_evaluation_average, 3000, input_var,
  57. # label_var,
  58. # batch_index, plot_data_validate, plot_data_DV, plot_data_test)
  59. def ValidateVisual(reader_validate, reader_test, trainer, minibatch_evaluation_average, epoch_size, input_var,
  60. label_var, batch_index, plot_data_validate, plot_data_DV, plot_data_test):
  61. """
  62. 用于在训练的epoch中得出实时的测试集上的效果
  63. :param reader_validate:验证集txt
  64. :param reader_test:测试机txt
  65. :param trainer:训练过程中的trainer
  66. :param minibatch_evaluation_average:一个minibatch的训练集的误差率
  67. :param epoch_size:一共要测试的图片数目
  68. :param input_var:传入的占位符,用来存放图像内的数据
  69. :param label_var:传入的占位符,用来存放图像的标签
  70. :param batch_index:绘图的横坐标,训练的epoch
  71. :param plot_data_validate:dict,存放validate数据集的训练效果
  72. :param plot_data_DV:validate与训练集minibatch效果的差值
  73. :param plot_data_test:dict,存放test数据集的训练效果
  74. :return:None
  75. """
  76. minibatch_size = 64 #一次送进model训练的数目
  77. metric_numer = 0 #统计推理错误的数目
  78. metric_denom = 0 #统计总共的测试数目
  79. sample_count = 0 #统计目前的测试数目
  80. input_map_validate = {
  81. input_var: reader_validate.streams.features,
  82. label_var: reader_validate.streams.labels
  83. }
  84. while sample_count < epoch_size:
  85. current_minibatch = min(minibatch_size, epoch_size - sample_count)
  86. # Fetch next test min batch.
  87. data = reader_validate.next_minibatch(current_minibatch, input_map=input_map_validate)
  88. # minibatch data to be trained with
  89. metric_numer += trainer.test_minibatch(data) * current_minibatch
  90. metric_denom += current_minibatch
  91. # Keep track of the number of samples processed so far.
  92. sample_count += data[label_var].num_samples # 对训练过的样本进行统计
  93. plot_data_validate['batchindex'].append(batch_index)
  94. plot_data_validate['error'].append(float(metric_numer / metric_denom))
  95. error_validate = float((metric_numer) / metric_denom)
  96. error_DV = minibatch_evaluation_average - error_validate
  97. plot_data_DV['batchindex'].append(batch_index)
  98. plot_data_DV['error_DV'].append(error_DV)
  99. print("error_validate is " + str(error_validate * 100) + "%")
  100. print("error_DV is " + str(error_DV * 100) + "%")
  101. metric_numer_test = 0
  102. metric_denom_test = 0
  103. sample_count_test = 0
  104. input_map_test = {
  105. input_var: reader_test.streams.features,
  106. label_var: reader_test.streams.labels
  107. }
  108. while sample_count_test < epoch_size:
  109. # while sample_count_test < 6:
  110. current_minibatch = min(minibatch_size, epoch_size - sample_count_test)
  111. # current_minibatch = 6
  112. # Fetch next test min batch.
  113. data_test = reader_test.next_minibatch(current_minibatch, input_map=input_map_test) # 是否是随机截取
  114. # minibatch data to be trained with
  115. metric_numer_test += trainer.test_minibatch(data_test) * current_minibatch
  116. metric_denom_test += current_minibatch
  117. # Keep track of the number of samples processed so far.
  118. sample_count_test += data_test[label_var].num_samples # 对训练过的样本进行统计
  119. plot_data_test['batchindex'].append(batch_index)
  120. plot_data_test['error'].append(float(metric_numer_test / metric_denom_test))
  121. error_test = float((metric_numer_test) / metric_denom_test)
  122. print("error_test is " + str(error_test * 100) + "%") # 打印出
  123. return error_validate ,error_test
  124. def VisualValidate(plot_data, plot_data_validate, plot_data_DV, plot_data_test):
  125. """
  126. 最后画出整个训练曲线图
  127. :param plot_data: minibatch个训练集样本的训练曲线
  128. :param plot_data_validate: 存放validate集曲线
  129. :param plot_data_DV: plot_data与plot_data_validate的差值
  130. :param plot_data_test: 存放test集曲线
  131. :return: None
  132. """
  133. plt.subplot(211)
  134. plt.plot(plot_data["batchindex"], moving_average(plot_data["error"]), 'r--',
  135. plot_data_validate["batchindex"], moving_average(plot_data_validate["error"]), 'b--',
  136. plot_data_test["batchindex"], moving_average(plot_data_test["error"]), 'g--')
  137. plt.xlabel('Minibatch number')
  138. plt.ylabel('Label Prediction Error')
  139. plt.title('resnet'+str(INDEX))
  140. plt.savefig('train_line.jpg')
  141. plt.show()
  142. # plt.subplot(212)
  143. # plt.plot(plot_data_DV["batchindex"], moving_average(plot_data_DV["error_DV"]), 'g--')
  144. # plt.xlabel('Minibatch number')
  145. # plt.ylabel('Label Prediction Error_DV')
  146. # plt.title('DV')
  147. # plt.show()
  148. def moving_average(a, w=10):
  149. """
  150. 对测试中的误差进行求取平均,以达到曲线平滑的效果
  151. :param a: 训练曲线存放的误差数据,list
  152. :param w: 前后w个值求取平均
  153. :return: 求取均值后的误差数据
  154. """
  155. if len(a) < w:
  156. return a[:]
  157. return [val if idx < w else sum(a[(idx - w):idx]) / w for idx, val in enumerate(a)]
  158. def train_and_evaluate(reader_train, reader_validate, reader_test, max_epochs, create_basic_model):
  159. """
  160. 主训练程序
  161. :param reader_train:存放训练集的txt
  162. :param reader_validate:验证集
  163. :param reader_test:测试集
  164. :param max_epochs:迭代次数
  165. :param create_basic_model:model构造
  166. :return:训练好的model
  167. """
  168. # Input variables denoting the features and label data
  169. input_var = C.input_variable((num_channels, image_height, image_width))
  170. label_var = C.input_variable((num_classes))
  171. # Normalize the input
  172. feature_scale = 1.0 / 256.0
  173. input_var_norm = C.element_times(feature_scale, input_var)
  174. z = create_basic_model(input_var_norm, 2)
  175. # loss and metric
  176. ce = C.cross_entropy_with_softmax(z, label_var)
  177. pe = C.classification_error(z, label_var)
  178. # training config
  179. epoch_size = EPOCH_SIZE
  180. minibatch_size = MINIBATCH_SIZE
  181. # Set training parameters
  182. lr_per_minibatch = LEARNING_RATIO
  183. momentums = C.momentum_schedule(0.9, minibatch_size=minibatch_size)
  184. l2_reg_weight = L2_REG
  185. # trainer object
  186. learner = C.momentum_sgd(z.parameters,
  187. lr=lr_per_minibatch,
  188. momentum=momentums,
  189. l2_regularization_weight=l2_reg_weight)
  190. progress_printer = C.logging.ProgressPrinter(tag='Training', num_epochs=max_epochs)
  191. trainer = C.Trainer(z, (ce, pe), [learner], [progress_printer])
  192. # define mapping from reader streams to network inputs
  193. input_map = {
  194. input_var: reader_train.streams.features,
  195. label_var: reader_train.streams.labels
  196. }
  197. #存储中间checkpoint
  198. checkpoint = CHECKPOINT#'checkpoint/checkpointed_17.tmp'
  199. checkpoint_splited = checkpoint.split('.')
  200. if os.path.exists(checkpoint):
  201. print("Trying to restore from checkpoint")
  202. mb_source_state = trainer.restore_from_checkpoint(checkpoint)
  203. reader_train.restore_from_checkpoint(mb_source_state)
  204. print("Restore has finished successfully")
  205. else:
  206. print("No restore file found")
  207. checkpoint_frequency = epoch_size * 4
  208. last_checkpoint = 0
  209. C.logging.log_number_of_parameters(z)
  210. print() # 将训练参数输出
  211. # perform model training
  212. batch_index = 0
  213. plot_data = {'batchindex': [], 'loss': [], 'error': []} # 为可视化做准备
  214. plot_data_validate = {'batchindex': [], 'error': []} # 验证集效果
  215. plot_data_test = {'batchindex': [], 'error': []} # test集效果
  216. plot_data_DV = {'batchindex': [], 'error_DV': []} # 两者差值效果
  217. for epoch in range(max_epochs): # loop over epochs
  218. sample_count = 0
  219. while sample_count < epoch_size: # loop over minibatches in the epoch
  220. data = reader_train.next_minibatch(min(minibatch_size, epoch_size - sample_count), # 问题在这
  221. input_map=input_map)
  222. trainer.train_minibatch(data) # update model with it
  223. sample_count += data[label_var].num_samples # count samples processed so far
  224. #保存中间的checkpoint
  225. if int(trainer.total_number_of_samples_seen / checkpoint_frequency) != last_checkpoint:
  226. mb_source_state = reader_train.get_checkpoint_state()
  227. trainer.save_checkpoint(
  228. checkpoint_splited[0] + '_' + str(last_checkpoint) + '.' + checkpoint_splited[1], mb_source_state)
  229. last_checkpoint = int(trainer.total_number_of_samples_seen / checkpoint_frequency)
  230. # For visualization...
  231. plot_data['batchindex'].append(batch_index)
  232. plot_data['loss'].append(trainer.previous_minibatch_loss_average)
  233. plot_data['error'].append(
  234. trainer.previous_minibatch_evaluation_average) # minibatch的平均准确性,每一个元素都是minnibatch的精度均值
  235. minibatch_evaluation_average = trainer.previous_minibatch_evaluation_average # 0-1
  236. # 每300个index进行一次验证集的测试
  237. if batch_index % 300 == 0:
  238. error_v,error_t = ValidateVisual(reader_validate, reader_test, trainer, minibatch_evaluation_average, 3000, input_var,
  239. label_var,batch_index, plot_data_validate, plot_data_DV, plot_data_test)
  240. print("trainer.previous_minibatch_evaluation_average is " + str(
  241. minibatch_evaluation_average * 100) + "%") # 打印出,看其类型等
  242. batch_index += 1
  243. # 判断是否满足条件
  244. if (error_v < 0.03) & (error_t < 0.02) & epoch>150:
  245. is_break=True
  246. global is_break #恢复全局变量
  247. break
  248. trainer.summarize_training_progress() # 输出阶段性训练结果
  249. if is_break:
  250. break
  251. VisualValidate(plot_data, plot_data_validate, plot_data_DV, plot_data_test)
  252. return C.softmax(z) # tuple没有save方法
  253. def TransImgToData(trained_model, img_txt, PATH_1,PATH_2):
  254. """
  255. 对特定数据集进行推理统计,查看最终误差
  256. :param trained_model: 加载好的模型
  257. :param img_txt: 测试数据集
  258. :param PATH: 目标文件夹
  259. :return: None
  260. """
  261. with open(img_txt, 'r')as img_file:
  262. img_lines = img_file.readlines()
  263. img_path = []
  264. img_label = []
  265. metric_count = 0
  266. for index in range(len(img_lines)):
  267. img_path.append(img_lines[index].split('\t')[0])
  268. img_label.append(img_lines[index].split('\t')[1])
  269. orig_img = cv2.imdecode(np.fromfile(img_path[index], dtype=np.uint8), -1) # 兼容中文
  270. img_b=orig_img[:,:,0]
  271. img_g=orig_img[:,:,1]
  272. img_r=orig_img[:,:,2]
  273. orig_img = cv2.merge([img_b, img_g, img_r])
  274. # orig_img=cv2.imread(img_path[index]) #读入单张图像
  275. img_resized = cv2.resize(orig_img, (96,96), interpolation=cv2.INTER_LINEAR)
  276. model_arg_rep = np.ascontiguousarray(np.array(img_resized, dtype=np.float32).transpose(2, 0, 1))
  277. arguments = {trained_model.arguments[0]: [model_arg_rep]}
  278. start_time = time.time()
  279. output = trained_model.eval(arguments)
  280. end_time = time.time()
  281. print("Totaly time is " + str(end_time - start_time))
  282. output = output.tolist()
  283. pred_label = output[0].index(max(output[0]))
  284. if pred_label != int(img_label[index].strip()): # 如果model判断错误
  285. metric_count += 1
  286. print(img_path[index])
  287. path, name = os.path.split(img_path[index]) # 很方便,直接将路径和文件名分开
  288. newpath = os.path.join(PATH_2, name)
  289. shutil.copyfile(img_path[index], newpath) # 将file复制到newpath
  290. # else: #model判断正确
  291. # path, name = os.path.split(img_path[index])
  292. # newpath = os.path.join(PATH_1, name)
  293. # shutil.copyfile(img_path[index], newpath) # 将file复制到newpath
  294. metric = metric_count / len(img_lines) * 100
  295. print("total count is "+str(len(img_lines))+".")
  296. print("The metric is " + str(metric) + "%")
  297. return metric
  298. def CaculateMetric(txt_list,trained_model,PATH_1,PATH_2):
  299. """
  300. 批量调用TransImgToData函数
  301. :param txt_list:
  302. :param trained_model:
  303. :param PATH_1:
  304. :param PATH_2:
  305. :return:
  306. """
  307. total_txt_count=len(txt_list)
  308. metric_dict=OrderedDict()
  309. for index in range(total_txt_count):
  310. metric=TransImgToData(trained_model, txt_list[index], PATH_1,PATH_2)
  311. metric_dict[txt_list[index]] = metric
  312. return metric_dict
  313. """在model上进行推理"""
  314. #
  315. if __name__ == '__main__':
  316. if os.path.exists("cls_pred_uncutted_resnet_"+str(INDEX)+".model"):
  317. model_path = r"cls_pred_uncutted_resnet_"+str(INDEX)+".model"
  318. txt_list = ['new_breast.txt','new_dajiao.txt','new_linbajie.txt','new_Abd.txt','new_Cav.txt','new_heart.txt',
  319. 'new_jindongmai.txt','new_jingzhi.txt','new_natural_img.txt','new_real_img.txt','newForTest.txt',
  320. 'new_report.txt','new_Thy.txt','new_Thy_xingtai2.txt','new_test_brightness.txt',r'20190610\new_Breast_class1.txt',
  321. r'20190610\new_Breast_class2.txt','new_Breast20190604.txt','new_Thy20190604.txt','temp.txt']
  322. # txt_list =['new_breast.txt']
  323. PATH_1 = r'D:\JosephProjects\20190528乳腺分类model最终测试\Image_Insepction\correct'
  324. PATH_2 = r"D:\JosephProjects\20190528乳腺分类model最终测试\Image_Insepction\incorrect"
  325. print("test model...")
  326. # C.device.try_set_default_device(C.device.cpu()) # 在测试的时候选择cpu
  327. trained_model = C.Function.load(model_path)
  328. metric_dict = CaculateMetric(txt_list,trained_model,PATH_1,PATH_2)
  329. print(metric_dict)
  330. print("Done!")
  331. else:
  332. print("Let's training model...")
  333. reader_train = create_reader(r'20190610\train_new0606.txt',0, True)
  334. reader_validate = create_reader(r'20190610\validate_new0606.txt',0, False)
  335. reader_test = create_reader(r'20190610\new_Breast_class2.txt',0, False)
  336. pred = train_and_evaluate(reader_train, reader_validate, reader_test,
  337. max_epochs=200, create_basic_model=resnet_12)
  338. pred.save("cls_pred_uncutted_resnet_"+str(INDEX)+".model", format=C.ModelFormat.CNTKv2)
  339. print("Done!")