# split-data.py
# coding=utf-8
import os
import random
import shutil
import sys
import time
  3. # 将图片拆分成训练集train(0.8)和验证集val(0.2)
  4. def moveFile(Dir, train_ratio=0.8, val_ratio=0.2):
  5. if not os.path.exists(os.path.join(Dir, 'train')):
  6. os.makedirs(os.path.join(Dir, 'train'))
  7. if not os.path.exists(os.path.join(Dir, 'val')):
  8. os.makedirs(os.path.join(Dir, 'val'))
  9. filenames = []
  10. for root, dirs, files in os.walk(Dir):
  11. for name in files:
  12. filenames.append(name)
  13. break
  14. filenum = len(filenames)
  15. num_train = int(filenum * train_ratio)
  16. sample_train = random.sample(filenames, num_train)
  17. for name in sample_train:
  18. shutil.move(os.path.join(Dir, name), os.path.join(Dir, 'train'))
  19. sample_val = list(set(filenames).difference(set(sample_train)))
  20. for name in sample_val:
  21. shutil.move(os.path.join(Dir, name), os.path.join(Dir, 'val'))
  22. def remove_batch_file():
  23. file_path = r'D:\val'
  24. while(True):
  25. try:
  26. shutil.rmtree(file_path)
  27. except:
  28. if os.path.exists(file_path):
  29. continue
  30. else:
  31. break
  32. if __name__ == '__main__':
  33. Dir = r"F:\data"
  34. for root, dirs, files in os.walk(Dir):
  35. for name in dirs:
  36. folder = os.path.join(root, name)
  37. print("正在处理:" + folder)
  38. moveFile(folder)
  39. print("处理完成")
  40. break
  41. # Dir = r"D:\codes\pytorch_heart_classification\data_j"
  42. # dst_path = "D:\\codes\\pytorch_heart_classification"
  43. #
  44. # count = 0
  45. # class_folder = os.listdir(Dir)
  46. # for child in class_folder:
  47. # name = os.listdir(os.path.join(Dir, child))
  48. # for class_name in name:
  49. # for img in os.listdir(os.path.join(os.path.join(Dir, child), class_name)):
  50. # count += 1
  51. # print("第:{}张图片".format(count))
  52. #
  53. # if class_name == "train":
  54. # for img in os.listdir(os.path.join(os.path.join(Dir, child), class_name)):
  55. # if not os.path.exists(os.path.join(dst_path, class_name)):
  56. # os.mkdir(os.path.join(dst_path, class_name))
  57. #
  58. # if not os.path.exists(os.path.join(os.path.join(dst_path, class_name), child)):
  59. # os.mkdir(os.path.join(os.path.join(dst_path, class_name), child))
  60. #
  61. # shutil.copy(os.path.join(os.path.join(os.path.join(os.path.join(Dir, child)), class_name),img),
  62. # os.path.join(os.path.join(dst_path, class_name), child))
  63. #
  64. # if class_name == "val":
  65. # for img in os.listdir(os.path.join(os.path.join(Dir, child), class_name)):
  66. # if not os.path.exists(os.path.join(dst_path, class_name)):
  67. # os.mkdir(os.path.join(dst_path, class_name))
  68. #
  69. # if not os.path.exists(os.path.join(os.path.join(dst_path, class_name), child)):
  70. # os.mkdir(os.path.join(os.path.join(dst_path, class_name), child))
  71. #
  72. # shutil.copy(os.path.join(os.path.join(os.path.join(os.path.join(Dir, child)), class_name),img),
  73. # os.path.join(os.path.join(dst_path, class_name), child))
  74. #