# TrainSdk.py
import http.client
import os
import shutil
import urllib.parse
  4. # 获取训练文件夹中的文件数量
  5. # token:训练用的授权码
  6. def get_file_count(token):
  7. uri = "localhost:8888"
  8. conn = http.client.HTTPConnection(uri)
  9. conn.request("GET", "/GetFolderFileCount?token=" + token)
  10. res = conn.getresponse()
  11. if res.status == 200:
  12. count_data = res.read(4)
  13. #print(count_data)
  14. count = int.from_bytes(count_data, byteorder='little', signed=True)
  15. return count
  16. else:
  17. reason = str(res.reason)
  18. if len(reason) > 0:
  19. reason = ", reason: " + reason
  20. raise Exception("Get file count failed, could be url error or token error. Code: " + str(res.status) + reason)
  21. # 获取文件和其对应的标注数据,名称,是否是视频
  22. # token:训练用的授权码
  23. # index:文件序号
  24. def get_labeled_file(token, index):
  25. uri = "localhost:8888"
  26. conn = http.client.HTTPConnection(uri)
  27. conn.request("GET", "/GetTrainFile?token=" + token + "&index=" + str(index))
  28. res = conn.getresponse()
  29. if res.status == 200:
  30. file_size_data = res.read(4)
  31. file_size = int.from_bytes(file_size_data, byteorder='little', signed=True)
  32. file_data = res.read(file_size)
  33. label_size_data = res.read(4)
  34. label_size = int.from_bytes(label_size_data, byteorder='little', signed=True)
  35. label_data = res.read(label_size).decode("utf-8")
  36. file_name_size_data = res.read(4)
  37. file_name_size = int.from_bytes(file_name_size_data, byteorder='little', signed=True)
  38. file_name = res.read(file_name_size).decode("utf-8")
  39. file_isVideo_size_data = res.read(4)
  40. file_isVideo_size = int.from_bytes(file_isVideo_size_data, byteorder='little', signed=True)
  41. file_isVideo = res.read(file_isVideo_size).decode()
  42. return file_data, label_data, file_name, file_isVideo
  43. else:
  44. reason = str(res.reason)
  45. if len(reason) > 0:
  46. reason = ", reason: " + reason
  47. raise Exception(
  48. "Get file and label failed, could be url error or token error. Code: " + str(res.status) + reason)
  49. # 将训练结果的模型保存到系统可识别的路径下
  50. # trainedFile: 要保存的训练结果
  51. def save_output_model(trainedFile: object) -> object:
  52. uri = "localhost:8888"
  53. conn = http.client.HTTPConnection(uri)
  54. conn.request("GET", "/GetModelOutputFolder")
  55. res = conn.getresponse()
  56. if res.status == 200:
  57. outputFolder = res.readlines()[0].decode("utf-8")
  58. sourceFolder, sourceFileName = os.path.split(trainedFile)
  59. outputFile = os.path.join(outputFolder, sourceFileName)
  60. shutil.copy(trainedFile, outputFile)
  61. else:
  62. reason = str(res.reason)
  63. if len(reason) > 0:
  64. reason = ", reason: " + reason
  65. raise Exception(
  66. "Save output model failed, could be url error or token error. Code: " + str(res.status) + reason)
  67. # 获取测试文件夹中的文件数量
  68. # token:授权码
  69. def get_test_file_count(token):
  70. uri = "localhost:8888"
  71. conn = http.client.HTTPConnection(uri)
  72. conn.request("GET", "/GetTestFolderFileCount?token=" + token)
  73. res = conn.getresponse()
  74. if res.status == 200:
  75. count_data = res.read(4)
  76. #print(count_data)
  77. count = int.from_bytes(count_data, byteorder='little', signed=True)
  78. return count
  79. else:
  80. reason = str(res.reason)
  81. if len(reason) > 0:
  82. reason = ", reason: " + reason
  83. raise Exception(
  84. "Get test file count failed, could be url error or token error. Code: " + str(res.status) + reason)
  85. # 获取测试文件和其对应的标注数据,名称,是否是视频
  86. # token:授权码
  87. # index:文件序号
  88. def get_test_labeled_file(token, index):
  89. uri = "localhost:8888"
  90. conn = http.client.HTTPConnection(uri)
  91. conn.request("GET", "/GetTestFile?token=" + token + "&index=" + str(index))
  92. res = conn.getresponse()
  93. if res.status == 200:
  94. file_size_data = res.read(4)
  95. file_size = int.from_bytes(file_size_data, byteorder='little', signed=True)
  96. file_data = res.read(file_size)
  97. label_size_data = res.read(4)
  98. label_size = int.from_bytes(label_size_data, byteorder='little', signed=True)
  99. label_data = res.read(label_size).decode("utf-8")
  100. file_name_size_data = res.read(4)
  101. file_name_size = int.from_bytes(file_name_size_data, byteorder='little', signed=True)
  102. file_name = res.read(file_name_size).decode("utf-8")
  103. file_isVideo_size_data = res.read(4)
  104. file_isVideo_size = int.from_bytes(file_isVideo_size_data, byteorder='little', signed=True)
  105. file_isVideo = res.read(file_isVideo_size).decode()
  106. return file_data, label_data, file_name, file_isVideo
  107. else:
  108. reason = str(res.reason)
  109. if len(reason) > 0:
  110. reason = ", reason: " + reason
  111. raise Exception(
  112. "Get test file and label failed, could be url error or token error. Code: " + str(res.status) + reason)