yolo2coco.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import os
  2. import cv2
  3. import json
  4. from tqdm import tqdm
  5. from sklearn.model_selection import train_test_split
  6. import argparse
  7. # visdrone2019
  8. classes = ['pedestrain', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument('--image_path', default='',type=str, help="path of images")
  11. parser.add_argument('--label_path', default='',type=str, help="path of labels .txt")
  12. parser.add_argument('--save_path', default='data.json', type=str, help="if not split the dataset, give a path to a json file")
  13. arg = parser.parse_args()
  14. def yolo2coco(arg):
  15. print("Loading data from ", arg.image_path, arg.label_path)
  16. assert os.path.exists(arg.image_path)
  17. assert os.path.exists(arg.label_path)
  18. originImagesDir = arg.image_path
  19. originLabelsDir = arg.label_path
  20. # images dir name
  21. indexes = os.listdir(originImagesDir)
  22. dataset = {'categories': [], 'annotations': [], 'images': []}
  23. for i, cls in enumerate(classes, 0):
  24. dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'})
  25. # 标注的id
  26. ann_id_cnt = 0
  27. for k, index in enumerate(tqdm(indexes)):
  28. # 支持 png jpg 格式的图片.
  29. txtFile = f'{index[:index.rfind(".")]}.txt'
  30. stem = index[:index.rfind(".")]
  31. # 读取图像的宽和高
  32. try:
  33. im = cv2.imread(os.path.join(originImagesDir, index))
  34. height, width, _ = im.shape
  35. except Exception as e:
  36. print(f'{os.path.join(originImagesDir, index)} read error.\nerror:{e}')
  37. # 添加图像的信息
  38. if not os.path.exists(os.path.join(originLabelsDir, txtFile)):
  39. # 如没标签,跳过,只保留图片信息.
  40. continue
  41. dataset['images'].append({'file_name': index,
  42. 'id': stem,
  43. 'width': width,
  44. 'height': height})
  45. with open(os.path.join(originLabelsDir, txtFile), 'r') as fr:
  46. labelList = fr.readlines()
  47. for label in labelList:
  48. label = label.strip().split()
  49. x = float(label[1])
  50. y = float(label[2])
  51. w = float(label[3])
  52. h = float(label[4])
  53. # convert x,y,w,h to x1,y1,x2,y2
  54. H, W, _ = im.shape
  55. x1 = (x - w / 2) * W
  56. y1 = (y - h / 2) * H
  57. x2 = (x + w / 2) * W
  58. y2 = (y + h / 2) * H
  59. # 标签序号从0开始计算, coco2017数据集标号混乱,不管它了。
  60. cls_id = int(label[0])
  61. width = max(0, x2 - x1)
  62. height = max(0, y2 - y1)
  63. dataset['annotations'].append({
  64. 'area': width * height,
  65. 'bbox': [x1, y1, width, height],
  66. 'category_id': cls_id,
  67. 'id': ann_id_cnt,
  68. 'image_id': stem,
  69. 'iscrowd': 0,
  70. # mask, 矩形是从左上角点按顺时针的四个顶点
  71. 'segmentation': [[x1, y1, x2, y1, x2, y2, x1, y2]]
  72. })
  73. ann_id_cnt += 1
  74. # 保存结果
  75. with open(arg.save_path, 'w') as f:
  76. json.dump(dataset, f)
  77. print('Save annotation to {}'.format(arg.save_path))
  78. if __name__ == "__main__":
  79. yolo2coco(arg)