# TrainSdk.py
  1. # -*- coding: utf-8 -*-
  2. import http.client
  3. import os
  4. import shutil
  5. # 获取训练文件夹中的图片数量
  6. # token:训练用的授权码
  7. def get_file_count(token):
  8. uri = "localhost:8888"
  9. conn = http.client.HTTPConnection(uri)
  10. conn.request("GET", "/GetFolderFileCount?token=" + token)
  11. res = conn.getresponse()
  12. if res.status == 200:
  13. count_data = res.read(4)
  14. # print(count_data)
  15. count = int.from_bytes(count_data, byteorder='little', signed=True)
  16. return count
  17. else:
  18. reason = str(res.reason)
  19. if len(reason) > 0:
  20. reason = ", reason: " + reason
  21. raise Exception(
  22. "Get file count failed, could be url error or token error. Code: " + str(res.status) + reason)
  23. # 获取图片和其对应的标注数据,图片名称,已标注图像Id
  24. # token:训练用的授权码
  25. # index:图片序号
  26. def get_labeled_file(token, index):
  27. uri = "localhost:8888"
  28. conn = http.client.HTTPConnection(uri)
  29. conn.request("GET", "/GetTrainFile?token=" + token + "&index=" + str(index))
  30. res = conn.getresponse()
  31. if res.status == 200:
  32. file_size_data = res.read(4)
  33. file_size = int.from_bytes(file_size_data, byteorder='little', signed=True)
  34. file_data = res.read(file_size)
  35. label_size_data = res.read(4)
  36. label_size = int.from_bytes(label_size_data, byteorder='little', signed=True)
  37. label_data = res.read(label_size).decode("utf-8")
  38. file_name_size_data = res.read(4)
  39. file_name_size = int.from_bytes(file_name_size_data, byteorder='little', signed=True)
  40. file_name = res.read(file_name_size).decode("utf-8")
  41. file_isVideo_size_data = res.read(4)
  42. file_isVideo_size = int.from_bytes(file_isVideo_size_data, byteorder='little', signed=True)
  43. file_isVideo = res.read(file_isVideo_size).decode()
  44. return file_data, label_data, file_name, file_isVideo
  45. else:
  46. reason = str(res.reason)
  47. if len(reason) > 0:
  48. reason = ", reason: " + reason
  49. raise Exception(
  50. "Get file and label failed, could be url error or token error. Code: " + str(res.status) + reason)
  51. # 将训练结果的模型保存到系统可识别的路径下
  52. # trainedFile: 要保存的训练结果
  53. def save_output_model(trainedFile):
  54. uri = "localhost:8888"
  55. conn = http.client.HTTPConnection(uri)
  56. conn.request("GET", "/GetModelOutputFolder")
  57. res = conn.getresponse()
  58. if res.status == 200:
  59. outputFolder = res.readlines()[0].decode("utf-8")
  60. sourceFolder, sourceFileName = os.path.split(trainedFile)
  61. outputFile = os.path.join(outputFolder, sourceFileName)
  62. shutil.copy(trainedFile, outputFile)
  63. else:
  64. reason = str(res.reason)
  65. if len(reason) > 0:
  66. reason = ", reason: " + reason
  67. raise Exception(
  68. "Save output model failed, could be url error or token error. Code: " + str(res.status) + reason)
  69. # 获取测试文件夹中的图片数量
  70. # token:授权码
  71. def get_test_file_count(token):
  72. uri = "localhost:8888"
  73. conn = http.client.HTTPConnection(uri)
  74. conn.request("GET", "/GetTestFolderFileCount?token=" + token)
  75. res = conn.getresponse()
  76. if res.status == 200:
  77. count_data = res.read(4)
  78. # print(count_data)
  79. count = int.from_bytes(count_data, byteorder='little', signed=True)
  80. return count
  81. else:
  82. reason = str(res.reason)
  83. if len(reason) > 0:
  84. reason = ", reason: " + reason
  85. raise Exception("Get test file count failed, could be url error or token error. Code: " + str(res.status) + reason)
  86. # 获取测试图片和其对应的标注数据,图片名称,已标注图像Id
  87. # token:授权码
  88. # index:图片序号
  89. def get_test_labeled_file(token, index):
  90. uri = "localhost:8888"
  91. conn = http.client.HTTPConnection(uri)
  92. conn.request("GET", "/GetTestFile?token=" + token + "&index=" + str(index))
  93. res = conn.getresponse()
  94. if res.status == 200:
  95. file_size_data = res.read(4)
  96. file_size = int.from_bytes(file_size_data, byteorder='little', signed=True)
  97. file_data = res.read(file_size)
  98. label_size_data = res.read(4)
  99. label_size = int.from_bytes(label_size_data, byteorder='little', signed=True)
  100. label_data = res.read(label_size).decode("utf-8")
  101. file_name_size_data = res.read(4)
  102. file_name_size = int.from_bytes(file_name_size_data, byteorder='little', signed=True)
  103. file_name = res.read(file_name_size).decode("utf-8")
  104. file_isVideo_size_data = res.read(4)
  105. file_isVideo_size = int.from_bytes(file_isVideo_size_data, byteorder='little', signed=True)
  106. file_isVideo = res.read(file_isVideo_size).decode()
  107. return file_data, label_data, file_name, file_isVideo
  108. else:
  109. reason = str(res.reason)
  110. if len(reason) > 0:
  111. reason = ", reason: " + reason
  112. raise Exception("Get test file and label failed, could be url error or token error. Code: " + str(res.status) + reason)