demo.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # -*- coding: utf-8 -*-
  2. '''
  3. @Time : 20/04/25 15:49
  4. @Author : huguanghao
  5. @File : demo.py
  6. @Notice :
  7. @Modification :
  8. @Author :
  9. @Time :
  10. @Detail :
  11. '''
  12. # import sys
  13. # import time
  14. # from PIL import Image, ImageDraw
  15. # from models.tiny_yolo import TinyYoloNet
  16. from tool.utils import *
  17. from tool.torch_utils import *
  18. from tool.darknet2pytorch import Darknet
  19. import argparse
  # Hyper parameters.
  # When True, every detect_* function below moves the model to the GPU with
  # ``m.cuda()`` and passes the flag on to ``do_detect``.
  use_cuda = True
  def detect_cv2(cfgfile, weightfile, imgfile):
      """Run detection on a single image read with OpenCV and save the result.

      Builds a ``Darknet`` model from ``cfgfile``, loads ``weightfile``, runs
      ``do_detect`` twice on the resized RGB image and draws the detections of
      the last run onto the original image, saved as ``predictions.jpg``.

      Args:
          cfgfile: Path of the Darknet cfg file describing the network.
          weightfile: Path of the weights handed to ``Darknet.load_weights``.
          imgfile: Path of the image to run detection on.

      Raises:
          OSError: If OpenCV cannot read ``imgfile``.
      """
      import time

      import cv2

      m = Darknet(cfgfile)
      m.print_network()
      m.load_weights(weightfile)
      print('Loading weights from %s... Done!' % (weightfile))

      if use_cuda:
          m.cuda()

      # The class-name list is picked from the number of classes of the model.
      num_classes = m.num_classes
      if num_classes == 20:
          namesfile = 'data/voc.names'
      elif num_classes == 80:
          namesfile = 'data/coco.names'
      else:
          namesfile = 'data/x.names'
      class_names = load_class_names(namesfile)

      img = cv2.imread(imgfile)
      if img is None:
          # cv2.imread signals failure by returning None instead of raising;
          # without this check the error would only surface inside cv2.resize.
          raise OSError('Could not read image file: %s' % imgfile)

      # The network input is the image resized to the model size, in RGB order
      # (OpenCV loads images as BGR).
      sized = cv2.resize(img, (m.width, m.height))
      sized = cv2.cvtColor(sized, cv2.COLOR_BGR2RGB)

      # Two passes; only the second one is timed on screen (presumably so that
      # one-off start-up cost does not distort the reported time).
      for i in range(2):
          start = time.time()
          boxes = do_detect(m, sized, 0.4, 0.6, use_cuda)
          finish = time.time()
          if i == 1:
              print('%s: Predicted in %f seconds.' % (imgfile, (finish - start)))

      # boxes[0]: detections belonging to the first (only) image passed in.
      plot_boxes_cv2(img, boxes[0], savename='predictions.jpg', class_names=class_names)
  def detect_cv2_camera(cfgfile, weightfile):
      """Run detection on frames from camera 0 and show them in a window.

      Builds a ``Darknet`` model from ``cfgfile``, loads ``weightfile`` and then
      loops over camera frames: each frame is resized, converted to RGB, passed
      to ``do_detect`` and displayed with its boxes in the ``Yolo demo`` window.
      The loop ends when a frame cannot be read or when ``q`` / ``Esc`` is
      pressed in the window; the camera and the window are released either way.

      Args:
          cfgfile: Path of the Darknet cfg file describing the network.
          weightfile: Path of the weights handed to ``Darknet.load_weights``.

      Raises:
          OSError: If camera 0 cannot be opened.
      """
      import time

      import cv2

      m = Darknet(cfgfile)
      m.print_network()
      m.load_weights(weightfile)
      print('Loading weights from %s... Done!' % (weightfile))

      if use_cuda:
          m.cuda()

      # The class-name list is picked from the number of classes of the model.
      # It is loaded before the camera is opened so that a failure here cannot
      # leave the capture device open.
      num_classes = m.num_classes
      if num_classes == 20:
          namesfile = 'data/voc.names'
      elif num_classes == 80:
          namesfile = 'data/coco.names'
      else:
          namesfile = 'data/x.names'
      class_names = load_class_names(namesfile)

      cap = cv2.VideoCapture(0)
      # cap = cv2.VideoCapture("./test.mp4")
      try:
          if not cap.isOpened():
              raise OSError('Could not open camera 0')

          # Named constants for the property ids 3 and 4.
          cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
          cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
          print("Starting the YOLO loop...")

          while True:
              ret, img = cap.read()
              if not ret:
                  # No frame delivered (device unplugged / stream finished).
                  break

              sized = cv2.resize(img, (m.width, m.height))
              sized = cv2.cvtColor(sized, cv2.COLOR_BGR2RGB)

              start = time.time()
              boxes = do_detect(m, sized, 0.4, 0.6, use_cuda)
              finish = time.time()
              print('Predicted in %f seconds.' % (finish - start))

              result_img = plot_boxes_cv2(img, boxes[0], savename=None, class_names=class_names)

              cv2.imshow('Yolo demo', result_img)
              # waitKey also services the GUI event loop; 27 is the Esc key.
              if cv2.waitKey(1) & 0xFF in (27, ord('q')):
                  break
      finally:
          # Previously unreachable because the loop never terminated.
          cap.release()
          cv2.destroyAllWindows()
  def detect_skimage(cfgfile, weightfile, imgfile):
      """Run detection on a single image read with scikit-image and save it.

      Same flow as ``detect_cv2`` but the image is loaded and resized with
      scikit-image. The detections of the last run are drawn onto the original
      image and saved as ``predictions.jpg``.

      Args:
          cfgfile: Path of the Darknet cfg file describing the network.
          weightfile: Path of the weights handed to ``Darknet.load_weights``.
          imgfile: Path of the image to run detection on.
      """
      import time

      from skimage import io
      from skimage.transform import resize

      m = Darknet(cfgfile)
      m.print_network()
      m.load_weights(weightfile)
      print('Loading weights from %s... Done!' % (weightfile))

      if use_cuda:
          m.cuda()

      # The class-name list is picked from the number of classes of the model.
      num_classes = m.num_classes
      if num_classes == 20:
          namesfile = 'data/voc.names'
      elif num_classes == 80:
          namesfile = 'data/coco.names'
      else:
          namesfile = 'data/x.names'
      class_names = load_class_names(namesfile)

      img = io.imread(imgfile)
      # skimage's resize takes the output shape as (rows, cols), i.e.
      # (height, width) -- the reverse of the (width, height) that cv2.resize
      # expects. The result is scaled by 255 exactly as before.
      sized = resize(img, (m.height, m.width)) * 255

      # Two passes; only the second one is timed on screen (presumably so that
      # one-off start-up cost does not distort the reported time).
      for i in range(2):
          start = time.time()
          boxes = do_detect(m, sized, 0.4, 0.4, use_cuda)
          finish = time.time()
          if i == 1:
              print('%s: Predicted in %f seconds.' % (imgfile, (finish - start)))

      # The other detect functions in this file index the do_detect result with
      # [0] before plotting; do the same here so a single image's boxes are
      # passed rather than the whole result.
      # NOTE(review): img is in skimage's channel order here; confirm that
      # plot_boxes_cv2 writes the colours as intended for such input.
      plot_boxes_cv2(img, boxes[0], savename='predictions.jpg', class_names=class_names)
  def get_args(argv=None):
      """Parse the command-line options of the demo.

      Args:
          argv: Optional list of argument strings to parse. ``None`` (the
              default) makes argparse read the process arguments, which is
              what the script entry point relies on.

      Returns:
          argparse.Namespace with the string attributes ``cfgfile``,
          ``weightfile`` and ``imgfile``. Passing ``-imgfile ""`` yields an
          empty ``imgfile``.
      """
      # The text is the description of the tool; passed positionally it would
      # be taken as the program name shown in the usage line.
      parser = argparse.ArgumentParser(
          description='Test your image or video by trained model.')
      parser.add_argument('-cfgfile', type=str, default='./cfg/yolov4.cfg',
                          help='path of cfg file', dest='cfgfile')
      parser.add_argument('-weightfile', type=str,
                          default='./checkpoints/Yolov4_epoch1.pth',
                          help='path of trained model.', dest='weightfile')
      parser.add_argument('-imgfile', type=str,
                          default='./data/mscoco2017/train2017/190109_180343_00154162.jpg',
                          help='path of your image file.', dest='imgfile')
      return parser.parse_args(argv)
  if __name__ == '__main__':
      # Script entry point: a non-empty -imgfile runs detection on that image,
      # an empty one (-imgfile "") switches to the live camera demo.
      args = get_args()
      if args.imgfile:
          detect_cv2(args.cfgfile, args.weightfile, args.imgfile)
          # Alternative image path using scikit-image instead of OpenCV:
          # detect_skimage(args.cfgfile, args.weightfile, args.imgfile)
      else:
          detect_cv2_camera(args.cfgfile, args.weightfile)