TrainSdk.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import http.client
  2. import shutil
  3. import os
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import json
  8. # 获取训练文件夹中的图片数量
  9. # token:训练用的授权码
  10. def get_image_count(token):
  11. uri = "localhost:8888"
  12. conn = http.client.HTTPConnection(uri)
  13. conn.request("GET", "/GetFolderFileCount?token=" + token)
  14. res = conn.getresponse()
  15. if res.status == 200:
  16. count_data = res.read(4)
  17. #print(count_data)
  18. count = int.from_bytes(count_data, byteorder='little', signed=True)
  19. return count
  20. else:
  21. reason = str(res.reason)
  22. if len(reason) > 0 :
  23. reason = ", reason: " + reason
  24. raise Exception("Get file count failed, could be url error or token error. Code: " + str(res.status) + reason)
  25. # 获取图片和其对应的标注数据
  26. # token:训练用的授权码
  27. # index:图片序号
  28. def get_labeled_image(token, index):
  29. uri = "localhost:8888"
  30. conn = http.client.HTTPConnection(uri)
  31. conn.request("GET", "/GetTrainFile?token=" + token + "&index=" + str(index))
  32. res = conn.getresponse()
  33. if res.status == 200:
  34. image_size_data = res.read(4)
  35. image_size = int.from_bytes(image_size_data, byteorder='little', signed=True)
  36. image_data = res.read(image_size)
  37. label_size_data = res.read(4)
  38. label_size = int.from_bytes(label_size_data, byteorder='little', signed=True)
  39. label_data = res.read(label_size).decode("utf8")
  40. file_name_size_data = res.read(4)
  41. file_name_size = int.from_bytes(file_name_size_data, byteorder='little', signed=True)
  42. file_name = res.read(file_name_size).decode("utf8")
  43. return image_data, label_data, file_name
  44. else:
  45. raise Exception("Get image and label failed, could be url error or token error. Code: " + str(res.status))
  46. # 将训练结果的模型保存到系统可识别的路径下
  47. # trainedFile: 要保存的训练结果
  48. def save_output_model(trainedFile):
  49. uri = "localhost:8888"
  50. conn = http.client.HTTPConnection(uri)
  51. conn.request("GET", "/GetModelOutputFolder")
  52. res = conn.getresponse()
  53. if res.status == 200:
  54. outputFolder = res.readlines()[0].decode("utf8")
  55. sourceFolder, sourceFileName = os.path.split(trainedFile)
  56. outputFile = os.path.join(outputFolder, sourceFileName)
  57. shutil.copy(trainedFile, outputFile)
  58. else:
  59. raise Exception("Save output model failed, could be url error or token error. Code: " + str(res.status))
  60. # 获取测试文件夹中的图片数量
  61. # token:授权码
  62. def get_test_image_count(token):
  63. uri = "localhost:8888"
  64. conn = http.client.HTTPConnection(uri)
  65. conn.request("GET", "/GetTestFolderFileCount?token=" + token)
  66. res = conn.getresponse()
  67. if res.status == 200:
  68. count_data = res.read(4)
  69. count = int.from_bytes(count_data, byteorder='little', signed=True)
  70. return count
  71. else:
  72. raise Exception("Get test image count failed, could be url error or token error. Code: " + str(res.status))
  73. # 获取测试图片和其对应的标注数据
  74. # token:授权码
  75. # index:图片序号
  76. def get_test_labeled_image(token, index):
  77. uri = "localhost:8888"
  78. conn = http.client.HTTPConnection(uri)
  79. conn.request("GET", "/GetTestFile?token=" + token + "&index=" + str(index))
  80. res = conn.getresponse()
  81. if res.status == 200:
  82. image_size_data = res.read(4)
  83. image_size = int.from_bytes(image_size_data, byteorder='little', signed=True)
  84. image_data = res.read(image_size)
  85. label_size_data = res.read(4)
  86. label_size = int.from_bytes(label_size_data, byteorder='little', signed=True)
  87. label_data = res.read(label_size).decode("utf8")
  88. file_name_size_data = res.read(4)
  89. file_name_size = int.from_bytes(file_name_size_data, byteorder='little', signed=True)
  90. file_name = res.read(file_name_size).decode("utf8")
  91. return image_data, label_data, file_name
  92. else:
  93. return bytearray(), "", "" # 读取数据时有无法读到的现象
  94. # raise Exception("Get test image and label failed, could be url error or token error. Code: " + str(res.status))
  95. def label_preprocess(label):
  96. class_dict = {'BMode':0, 'BModeBlood':1,'Pseudocolor':2, 'PseudocolorBlood':3, 'Spectrogram':4, 'CEUS':5, 'SE':6,'STE':7,'FourDime':8}
  97. try:
  98. txt_info_dict = json.loads(label)
  99. classes = txt_info_dict[0]["FileResultInfos"][0]["LabeledResult"]["ImageResults"][0]["Conclusion"]["Title"]
  100. dst_label = class_dict[classes]
  101. return dst_label
  102. except Exception as e:
  103. print('label data process wrong!,{}'.format(e))
  104. def preprocess(image):
  105. nparr_data = np.frombuffer(image, dtype=np.uint8)
  106. img_data = cv2.imdecode(nparr_data, cv2.IMREAD_GRAYSCALE)
  107. img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)
  108. print(img_data)
  109. # img_data = transform_image(img_data)
  110. # print(img_data)
  111. img_data = cv2.resize(img_data, (224, 224))
  112. if len(img_data.shape) == 2:
  113. img_data = np.expand_dims(img_data, axis=2)
  114. if img_data.max() > 1:
  115. img_data = img_data.astype('float32') / 255
  116. img_data = img_data.transpose((2, 0, 1))
  117. img_data = torch.from_numpy(img_data).type(torch.FloatTensor)
  118. return img_data
  119. if __name__ == "__main__":
  120. image_count = get_test_image_count("4925EC4929684AA0ABB0173B03CFC8FF")
  121. for index in range(image_count):
  122. image, label, name = get_test_labeled_image("4925EC4929684AA0ABB0173B03CFC8FF", index)
  123. if label != "":
  124. dst_label = label_preprocess(label)
  125. txt_info_dict = json.loads(label)
  126. # print(index, name, txt_info_dict[0]["LabeledUltrasoundFileId"], dst_label)
  127. else:
  128. print(index)