from __future__ import print_function  # Python 2/3 compatible print
from cntk.initializer import he_normal, normal
from cntk.layers import AveragePooling, MaxPooling, BatchNormalization, Convolution, Dense
from cntk.ops import element_times, relu
from typing import Any, Union, Tuple
import matplotlib.pyplot as plt
import numpy as np
import os
import cntk as C
import cntk.io.transforms as xforms
import cv2
import time
import shutil
from resnet_model_12 import conv_bn, conv_bn_relu, resnet_basic, resnet_basic_inc, resnet_basic_stack, resnet_bottleneck, \
    resnet_bottleneck_inc, resnet_bottleneck_stack, resnet_12
from collections import OrderedDict
# Image and label configuration
image_height = 96
image_width = 96
num_channels = 3
num_classes = 2
EPOCH_SIZE = 78789
MINIBATCH_SIZE = 70
INDEX = 43
CHECKPOINT = 'checkpoint/checkpointed.tmp'
# Learning-rate schedule: 0.005 for the first 70 epochs, 0.001 for the next 70,
# then 0.0001 for the rest (each list entry covers epoch_size samples).
LEARNING_RATIO = C.learning_parameter_schedule([0.005] * 70 + [0.001] * 70 + [0.0001], epoch_size=EPOCH_SIZE)
L2_REG = 0.02
is_break = False  # global flag: set once the early-stopping criteria are met
def create_reader(map_file, mean_file, train):
    """
    Build a MinibatchSource the model can read from, based on a map-file txt.
    :param map_file: txt file listing the image paths and labels
    :param mean_file: mean file of all images (unused here)
    :param train: bool, whether this reader is used for training
    :return: batch_source, from which image data and labels can be drawn
    """
    transforms = []
    if train:
        transforms += [xforms.crop(crop_type='randomside', side_ratio=1.0)]
    transforms += [xforms.scale(width=image_width, height=image_height, channels=num_channels),
                   # xforms.mean(mean_file)
                   ]
    # Note: this brightness jitter is applied to evaluation readers as well
    transforms += [
        xforms.color(brightness_radius=0.3, contrast_radius=0.0, saturation_radius=0.0)
    ]
    # deserializer
    features = C.io.StreamDef(field='image', transforms=transforms)
    labels = C.io.StreamDef(field='label', shape=num_classes)
    streamdefs = C.io.StreamDefs(
        features=features,  # the first column in the map file is referred to as 'image'
        labels=labels       # and the second as 'label'
    )
    imagedeserialized = C.io.ImageDeserializer(map_file, streamdefs)
    batch_source = C.io.MinibatchSource(imagedeserialized)
    return batch_source
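# A minimal sketch of the map-file format that ImageDeserializer (and the
# inference code further below) expects: one sample per line, an image path
# and a zero-based label index separated by a tab. The file names here are
# hypothetical:
#
#   images/sample_0001.jpg	0
#   images/sample_0002.jpg	1
#
# Hypothetical usage:
#   reader = create_reader(r'20190610\train_new0606.txt', 0, True)
#   mb = reader.next_minibatch(4)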
def ValidateVisual(reader_validate, reader_test, trainer, minibatch_evaluation_average, epoch_size, input_var,
                   label_var, batch_index, plot_data_validate, plot_data_DV, plot_data_test):
    """
    Evaluate the model on the validation and test sets during training.
    :param reader_validate: reader for the validation set
    :param reader_test: reader for the test set
    :param trainer: the trainer used during training
    :param minibatch_evaluation_average: error rate on one training minibatch
    :param epoch_size: total number of images to evaluate
    :param input_var: placeholder holding the image data
    :param label_var: placeholder holding the image labels
    :param batch_index: x-coordinate for plotting (the training minibatch index)
    :param plot_data_validate: dict holding the validation-set error curve
    :param plot_data_DV: dict holding the gap between validation and training minibatch error
    :param plot_data_test: dict holding the test-set error curve
    :return: (error_validate, error_test)
    """
    minibatch_size = 64  # number of samples fed to the model per evaluation batch
    metric_numer = 0     # accumulated (sample-weighted) error
    metric_denom = 0     # total number of samples evaluated
    sample_count = 0     # number of samples evaluated so far
    input_map_validate = {
        input_var: reader_validate.streams.features,
        label_var: reader_validate.streams.labels
    }
    while sample_count < epoch_size:
        current_minibatch = min(minibatch_size, epoch_size - sample_count)
        # Fetch the next validation minibatch.
        data = reader_validate.next_minibatch(current_minibatch, input_map=input_map_validate)
        metric_numer += trainer.test_minibatch(data) * current_minibatch
        metric_denom += current_minibatch
        # Keep track of the number of samples processed so far.
        sample_count += data[label_var].num_samples
    plot_data_validate['batchindex'].append(batch_index)
    plot_data_validate['error'].append(float(metric_numer / metric_denom))
    error_validate = float(metric_numer / metric_denom)
    error_DV = minibatch_evaluation_average - error_validate
    plot_data_DV['batchindex'].append(batch_index)
    plot_data_DV['error_DV'].append(error_DV)
    print("error_validate is " + str(error_validate * 100) + "%")
    print("error_DV is " + str(error_DV * 100) + "%")
    metric_numer_test = 0
    metric_denom_test = 0
    sample_count_test = 0
    input_map_test = {
        input_var: reader_test.streams.features,
        label_var: reader_test.streams.labels
    }
    while sample_count_test < epoch_size:
        current_minibatch = min(minibatch_size, epoch_size - sample_count_test)
        # Fetch the next test minibatch.
        data_test = reader_test.next_minibatch(current_minibatch, input_map=input_map_test)
        metric_numer_test += trainer.test_minibatch(data_test) * current_minibatch
        metric_denom_test += current_minibatch
        # Keep track of the number of samples processed so far.
        sample_count_test += data_test[label_var].num_samples
    plot_data_test['batchindex'].append(batch_index)
    plot_data_test['error'].append(float(metric_numer_test / metric_denom_test))
    error_test = float(metric_numer_test / metric_denom_test)
    print("error_test is " + str(error_test * 100) + "%")
    return error_validate, error_test
def VisualValidate(plot_data, plot_data_validate, plot_data_DV, plot_data_test):
    """
    Plot the full training curves at the end of training.
    :param plot_data: training-set minibatch error curve
    :param plot_data_validate: validation-set error curve
    :param plot_data_DV: gap between plot_data and plot_data_validate
    :param plot_data_test: test-set error curve
    :return: None
    """
    plt.subplot(211)
    plt.plot(plot_data["batchindex"], moving_average(plot_data["error"]), 'r--',
             plot_data_validate["batchindex"], moving_average(plot_data_validate["error"]), 'b--',
             plot_data_test["batchindex"], moving_average(plot_data_test["error"]), 'g--')
    plt.xlabel('Minibatch number')
    plt.ylabel('Label Prediction Error')
    plt.title('resnet' + str(INDEX))
    plt.savefig('train_line.jpg')
    plt.show()
    # plt.subplot(212)
    # plt.plot(plot_data_DV["batchindex"], moving_average(plot_data_DV["error_DV"]), 'g--')
    # plt.xlabel('Minibatch number')
    # plt.ylabel('Label Prediction Error_DV')
    # plt.title('DV')
    # plt.show()
def moving_average(a, w=10):
    """
    Smooth the error curve by averaging, so the plotted line is less noisy.
    Each value (after the first w) is replaced by the mean of the preceding w values.
    :param a: list of error values from the training curve
    :param w: window size for the average
    :return: the smoothed error values
    """
    if len(a) < w:
        return a[:]
    return [val if idx < w else sum(a[(idx - w):idx]) / w for idx, val in enumerate(a)]
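# Worked example: the first w values pass through unchanged; each later value
# is the mean of the w preceding entries.
#   moving_average([1, 2, 3, 4], w=2)  ->  [1, 2, 1.5, 2.5]
#   (index 2: (1 + 2) / 2 = 1.5; index 3: (2 + 3) / 2 = 2.5)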
def train_and_evaluate(reader_train, reader_validate, reader_test, max_epochs, create_basic_model):
    """
    Main training routine.
    :param reader_train: reader for the training set
    :param reader_validate: reader for the validation set
    :param reader_test: reader for the test set
    :param max_epochs: number of epochs to train
    :param create_basic_model: model constructor
    :return: the trained model, wrapped in softmax
    """
    global is_break
    # Input variables denoting the features and label data
    input_var = C.input_variable((num_channels, image_height, image_width))
    label_var = C.input_variable(num_classes)
    # Normalize the input
    feature_scale = 1.0 / 256.0
    input_var_norm = C.element_times(feature_scale, input_var)
    z = create_basic_model(input_var_norm, num_classes)
    # loss and metric
    ce = C.cross_entropy_with_softmax(z, label_var)
    pe = C.classification_error(z, label_var)
    # training config
    epoch_size = EPOCH_SIZE
    minibatch_size = MINIBATCH_SIZE
    # Set training parameters
    lr_per_minibatch = LEARNING_RATIO
    momentums = C.momentum_schedule(0.9, minibatch_size=minibatch_size)
    l2_reg_weight = L2_REG
    # trainer object
    learner = C.momentum_sgd(z.parameters,
                             lr=lr_per_minibatch,
                             momentum=momentums,
                             l2_regularization_weight=l2_reg_weight)
    progress_printer = C.logging.ProgressPrinter(tag='Training', num_epochs=max_epochs)
    trainer = C.Trainer(z, (ce, pe), [learner], [progress_printer])
    # define mapping from reader streams to network inputs
    input_map = {
        input_var: reader_train.streams.features,
        label_var: reader_train.streams.labels
    }
    # Restore an intermediate checkpoint if one exists
    checkpoint = CHECKPOINT
    checkpoint_splited = checkpoint.split('.')
    if os.path.exists(checkpoint):
        print("Trying to restore from checkpoint")
        mb_source_state = trainer.restore_from_checkpoint(checkpoint)
        reader_train.restore_from_checkpoint(mb_source_state)
        print("Restore has finished successfully")
    else:
        print("No restore file found")
    checkpoint_frequency = epoch_size * 4
    last_checkpoint = 0
    C.logging.log_number_of_parameters(z)  # print the number of trainable parameters
    print()
    # perform model training
    batch_index = 0
    plot_data = {'batchindex': [], 'loss': [], 'error': []}  # training-set curve
    plot_data_validate = {'batchindex': [], 'error': []}     # validation-set curve
    plot_data_test = {'batchindex': [], 'error': []}         # test-set curve
    plot_data_DV = {'batchindex': [], 'error_DV': []}        # validation/training gap
    error_v = error_t = None  # latest validation/test errors; updated every 300 minibatches
    for epoch in range(max_epochs):  # loop over epochs
        sample_count = 0
        while sample_count < epoch_size:  # loop over minibatches in the epoch
            data = reader_train.next_minibatch(min(minibatch_size, epoch_size - sample_count),
                                               input_map=input_map)
            trainer.train_minibatch(data)  # update model with it
            sample_count += data[label_var].num_samples  # count samples processed so far
            # Save an intermediate checkpoint every checkpoint_frequency samples
            if int(trainer.total_number_of_samples_seen / checkpoint_frequency) != last_checkpoint:
                mb_source_state = reader_train.get_checkpoint_state()
                trainer.save_checkpoint(
                    checkpoint_splited[0] + '_' + str(last_checkpoint) + '.' + checkpoint_splited[1], mb_source_state)
                last_checkpoint = int(trainer.total_number_of_samples_seen / checkpoint_frequency)
            # For visualization...
            plot_data['batchindex'].append(batch_index)
            plot_data['loss'].append(trainer.previous_minibatch_loss_average)
            plot_data['error'].append(
                trainer.previous_minibatch_evaluation_average)  # per-minibatch average error
            minibatch_evaluation_average = trainer.previous_minibatch_evaluation_average  # in [0, 1]
            # Evaluate on the validation and test sets every 300 minibatches
            if batch_index % 300 == 0:
                error_v, error_t = ValidateVisual(reader_validate, reader_test, trainer,
                                                  minibatch_evaluation_average, 3000, input_var,
                                                  label_var, batch_index, plot_data_validate, plot_data_DV,
                                                  plot_data_test)
                print("trainer.previous_minibatch_evaluation_average is " + str(
                    minibatch_evaluation_average * 100) + "%")
            batch_index += 1
            # Early stopping: stop once both errors are low enough late in training
            if error_v is not None and error_v < 0.03 and error_t < 0.02 and epoch > 150:
                is_break = True
                break
        trainer.summarize_training_progress()  # log the per-epoch training results
        if is_break:
            break
    VisualValidate(plot_data, plot_data_validate, plot_data_DV, plot_data_test)
    return C.softmax(z)  # wrap in softmax so the result is a savable Function
def TransImgToData(trained_model, img_txt, PATH_1, PATH_2):
    """
    Run inference over a data set and report the final error rate.
    :param trained_model: the loaded model
    :param img_txt: txt file listing the test images and labels
    :param PATH_1: destination folder for correctly classified images (currently unused)
    :param PATH_2: destination folder for misclassified images
    :return: the error rate in percent
    """
    with open(img_txt, 'r') as img_file:
        img_lines = img_file.readlines()
    img_path = []
    img_label = []
    metric_count = 0
    for index in range(len(img_lines)):
        img_path.append(img_lines[index].split('\t')[0])
        img_label.append(img_lines[index].split('\t')[1])
        orig_img = cv2.imdecode(np.fromfile(img_path[index], dtype=np.uint8), -1)  # handles non-ASCII paths
        # Keep only the first three channels (drops alpha for BGRA images)
        img_b = orig_img[:, :, 0]
        img_g = orig_img[:, :, 1]
        img_r = orig_img[:, :, 2]
        orig_img = cv2.merge([img_b, img_g, img_r])
        # orig_img = cv2.imread(img_path[index])  # alternative; fails on non-ASCII paths
        img_resized = cv2.resize(orig_img, (image_width, image_height), interpolation=cv2.INTER_LINEAR)
        model_arg_rep = np.ascontiguousarray(np.array(img_resized, dtype=np.float32).transpose(2, 0, 1))
        arguments = {trained_model.arguments[0]: [model_arg_rep]}
        start_time = time.time()
        output = trained_model.eval(arguments)
        end_time = time.time()
        print("Total time is " + str(end_time - start_time))
        output = output.tolist()
        pred_label = output[0].index(max(output[0]))
        if pred_label != int(img_label[index].strip()):  # the model misclassified this image
            metric_count += 1
            print(img_path[index])
            path, name = os.path.split(img_path[index])  # split directory and file name
            newpath = os.path.join(PATH_2, name)
            shutil.copyfile(img_path[index], newpath)  # copy the misclassified image to PATH_2
        # else:  # correctly classified
        #     path, name = os.path.split(img_path[index])
        #     newpath = os.path.join(PATH_1, name)
        #     shutil.copyfile(img_path[index], newpath)
    metric = 100.0 * metric_count / len(img_lines)
    print("total count is " + str(len(img_lines)) + ".")
    print("The metric is " + str(metric) + "%")
    return metric
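# A minimal single-image inference sketch mirroring the preprocessing above
# (the model and image file names are hypothetical):
#   model = C.Function.load("cls_pred_uncutted_resnet_43.model")
#   img = cv2.resize(cv2.imread("some_image.jpg"), (image_width, image_height))
#   data = np.ascontiguousarray(img.transpose(2, 0, 1).astype(np.float32))
#   probs = model.eval({model.arguments[0]: [data]})[0]
#   pred_label = int(np.argmax(probs))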
def CalculateMetric(txt_list, trained_model, PATH_1, PATH_2):
    """
    Call TransImgToData for each data set in txt_list.
    :param txt_list: list of map-file txts to evaluate
    :param trained_model: the loaded model
    :param PATH_1: destination folder for correctly classified images
    :param PATH_2: destination folder for misclassified images
    :return: OrderedDict mapping each txt file to its error rate
    """
    metric_dict = OrderedDict()
    for txt in txt_list:
        metric_dict[txt] = TransImgToData(trained_model, txt, PATH_1, PATH_2)
    return metric_dict
- """在model上进行推理"""
- #
if __name__ == '__main__':
    if os.path.exists("cls_pred_uncutted_resnet_" + str(INDEX) + ".model"):
        model_path = r"cls_pred_uncutted_resnet_" + str(INDEX) + ".model"
        txt_list = ['new_breast.txt', 'new_dajiao.txt', 'new_linbajie.txt', 'new_Abd.txt', 'new_Cav.txt', 'new_heart.txt',
                    'new_jindongmai.txt', 'new_jingzhi.txt', 'new_natural_img.txt', 'new_real_img.txt', 'newForTest.txt',
                    'new_report.txt', 'new_Thy.txt', 'new_Thy_xingtai2.txt', 'new_test_brightness.txt', r'20190610\new_Breast_class1.txt',
                    r'20190610\new_Breast_class2.txt', 'new_Breast20190604.txt', 'new_Thy20190604.txt', 'temp.txt']
        # txt_list = ['new_breast.txt']
        PATH_1 = r'D:\JosephProjects\20190528乳腺分类model最终测试\Image_Insepction\correct'
        PATH_2 = r"D:\JosephProjects\20190528乳腺分类model最终测试\Image_Insepction\incorrect"
        print("test model...")
        # C.device.try_set_default_device(C.device.cpu())  # use the CPU for testing
        trained_model = C.Function.load(model_path)
        metric_dict = CalculateMetric(txt_list, trained_model, PATH_1, PATH_2)
        print(metric_dict)
        print("Done!")
    else:
        print("Training the model...")
        reader_train = create_reader(r'20190610\train_new0606.txt', 0, True)
        reader_validate = create_reader(r'20190610\validate_new0606.txt', 0, False)
        reader_test = create_reader(r'20190610\new_Breast_class2.txt', 0, False)
        pred = train_and_evaluate(reader_train, reader_validate, reader_test,
                                  max_epochs=200, create_basic_model=resnet_12)
        pred.save("cls_pred_uncutted_resnet_" + str(INDEX) + ".model", format=C.ModelFormat.CNTKv2)
        print("Done!")