TrainSdk.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
import http.client
import os
import shutil
import urllib.parse
  4. #获取训练文件夹中的文件数量
  5. #token:训练用的授权码
  6. def get_file_count(token):
  7. uri = "localhost:8888"
  8. conn = http.client.HTTPConnection(uri)
  9. conn.request("GET", "/GetFolderFileCount?token=" + token)
  10. res = conn.getresponse()
  11. if res.status == 200:
  12. count_data = res.read(4)
  13. #print(count_data)
  14. count = int.from_bytes(count_data, byteorder='little', signed=True)
  15. return count
  16. else:
  17. reason = str(res.reason)
  18. if len(reason) > 0 :
  19. reason = ", reason: " + reason
  20. raise Exception("Get file count failed, could be url error or token error. Code: " + str(res.status) + reason)
  21. #获取文件和其对应的标注数据,名称,是否是视频
  22. #token:训练用的授权码
  23. #index:文件序号
  24. def get_labeled_file(token, index):
  25. uri = "localhost:8888"
  26. conn = http.client.HTTPConnection(uri)
  27. conn.request("GET", "/GetTrainFile?token=" + token + "&index=" + str(index))
  28. res = conn.getresponse()
  29. if res.status == 200:
  30. file_size_data = res.read(4)
  31. file_size = int.from_bytes(file_size_data, byteorder='little', signed=True)
  32. file_data = res.read(file_size)
  33. label_size_data = res.read(4)
  34. label_size = int.from_bytes(label_size_data, byteorder='little', signed=True)
  35. label_data = res.read(label_size).decode("utf-8")
  36. file_name_size_data = res.read(4)
  37. file_name_size = int.from_bytes(file_name_size_data, byteorder='little', signed=True)
  38. file_name = res.read(file_name_size).decode("utf-8")
  39. file_isVideo_size_data = res.read(4)
  40. file_isVideo_size = int.from_bytes(file_isVideo_size_data, byteorder='little', signed=True)
  41. file_isVideo = res.read(file_isVideo_size).decode()
  42. return file_data, label_data, file_name, file_isVideo
  43. else:
  44. reason = str(res.reason)
  45. if len(reason) > 0 :
  46. reason = ", reason: " + reason
  47. raise Exception("Get file and label failed, could be url error or token error. Code: " + str(res.status) + reason)
  48. #将训练结果的模型保存到系统可识别的路径下
  49. #trainedFile: 要保存的训练结果
  50. def save_output_model(trainedFile: object) -> object:
  51. uri = "localhost:8888"
  52. conn = http.client.HTTPConnection(uri)
  53. conn.request("GET", "/GetModelOutputFolder")
  54. res = conn.getresponse()
  55. if res.status == 200:
  56. outputFolder = res.readlines()[0].decode("utf-8")
  57. sourceFolder, sourceFileName = os.path.split(trainedFile)
  58. outputFile = os.path.join(outputFolder,sourceFileName)
  59. shutil.copy(trainedFile,outputFile)
  60. else:
  61. reason = str(res.reason)
  62. if len(reason) > 0 :
  63. reason = ", reason: " + reason
  64. raise Exception("Save output model failed, could be url error or token error. Code: " + str(res.status) + reason)
  65. #获取测试文件夹中的文件数量
  66. #token:授权码
  67. def get_test_file_count(token):
  68. uri = "localhost:8888"
  69. conn = http.client.HTTPConnection(uri)
  70. conn.request("GET", "/GetTestFolderFileCount?token=" + token)
  71. res = conn.getresponse()
  72. if res.status == 200:
  73. count_data = res.read(4)
  74. #print(count_data)
  75. count = int.from_bytes(count_data, byteorder='little', signed=True)
  76. return count
  77. else:
  78. reason = str(res.reason)
  79. if len(reason) > 0 :
  80. reason = ", reason: " + reason
  81. raise Exception("Get test file count failed, could be url error or token error. Code: " + str(res.status) + reason)
  82. #获取测试文件和其对应的标注数据,名称,是否是视频
  83. #token:授权码
  84. #index:文件序号
  85. def get_test_labeled_file(token, index):
  86. uri = "localhost:8888"
  87. conn = http.client.HTTPConnection(uri)
  88. conn.request("GET", "/GetTestFile?token=" + token + "&index=" + str(index))
  89. res = conn.getresponse()
  90. if res.status == 200:
  91. file_size_data = res.read(4)
  92. file_size = int.from_bytes(file_size_data, byteorder='little', signed=True)
  93. file_data = res.read(file_size)
  94. label_size_data = res.read(4)
  95. label_size = int.from_bytes(label_size_data, byteorder='little', signed=True)
  96. label_data = res.read(label_size).decode("utf-8")
  97. file_name_size_data = res.read(4)
  98. file_name_size = int.from_bytes(file_name_size_data, byteorder='little', signed=True)
  99. file_name = res.read(file_name_size).decode("utf-8")
  100. file_isVideo_size_data = res.read(4)
  101. file_isVideo_size = int.from_bytes(file_isVideo_size_data, byteorder='little', signed=True)
  102. file_isVideo = res.read(file_isVideo_size).decode()
  103. return file_data, label_data, file_name, file_isVideo
  104. else:
  105. reason = str(res.reason)
  106. if len(reason) > 0 :
  107. reason = ", reason: " + reason
  108. raise Exception("Get test file and label failed, could be url error or token error. Code: " + str(res.status) + reason)