# demo_pytorch2onnx.py
import sys
import os
import argparse

import cv2
import numpy as np
import onnx
import onnxruntime
import torch

from tool.utils import *
from models import Yolov4
  11. def transform_to_onnx(weight_file, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W):
  12. model = Yolov4(n_classes=n_classes, inference=True)
  13. pretrained_dict = torch.load(weight_file, map_location=torch.device('cuda'))
  14. model.load_state_dict(pretrained_dict)
  15. input_names = ["input"]
  16. output_names = ['boxes', 'confs']
  17. dynamic = False
  18. if batch_size <= 0:
  19. dynamic = True
  20. if dynamic:
  21. x = torch.randn((1, 3, IN_IMAGE_H, IN_IMAGE_W), requires_grad=True)
  22. onnx_file_name = "yolov4_-1_3_{}_{}_dynamic.onnx".format(IN_IMAGE_H, IN_IMAGE_W)
  23. dynamic_axes = {"input": {0: "batch_size"}, "boxes": {0: "batch_size"}, "confs": {0: "batch_size"}}
  24. # Export the model
  25. print('Export the onnx model ...')
  26. torch.onnx.export(model,
  27. x,
  28. onnx_file_name,
  29. export_params=True,
  30. opset_version=11,
  31. do_constant_folding=True,
  32. input_names=input_names, output_names=output_names,
  33. dynamic_axes=dynamic_axes)
  34. print('Onnx model exporting done')
  35. return onnx_file_name
  36. else:
  37. x = torch.randn((batch_size, 3, IN_IMAGE_H, IN_IMAGE_W), requires_grad=True)
  38. onnx_file_name = "yolov4_{}_3_{}_{}_static.onnx".format(batch_size, IN_IMAGE_H, IN_IMAGE_W)
  39. # Export the model
  40. print('Export the onnx model ...')
  41. torch.onnx.export(model,
  42. x,
  43. onnx_file_name,
  44. export_params=True,
  45. opset_version=11,
  46. do_constant_folding=True,
  47. input_names=input_names, output_names=output_names,
  48. dynamic_axes=None)
  49. print('Onnx model exporting done')
  50. return onnx_file_name
  51. def main(weight_file, image_path, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W):
  52. if batch_size <= 0:
  53. onnx_path_demo = transform_to_onnx(weight_file, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W)
  54. else:
  55. # Transform to onnx as specified batch size
  56. transform_to_onnx(weight_file, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W)
  57. # Transform to onnx for demo
  58. onnx_path_demo = transform_to_onnx(weight_file, 1, n_classes, IN_IMAGE_H, IN_IMAGE_W)
  59. session = onnxruntime.InferenceSession(onnx_path_demo)
  60. # session = onnx.load(onnx_path)
  61. print("The model expects input shape: ", session.get_inputs()[0].shape)
  62. # image_src = cv2.imread(image_path)
  63. # detect(session, image_src)
  64. if __name__ == '__main__':
  65. print("Converting to onnx and running demo ...")
  66. if len(sys.argv) == 7:
  67. weight_file = sys.argv[1]
  68. image_path = sys.argv[2]
  69. batch_size = int(sys.argv[3])
  70. n_classes = int(sys.argv[4])
  71. IN_IMAGE_H = int(sys.argv[5])
  72. IN_IMAGE_W = int(sys.argv[6])
  73. main(weight_file, image_path, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W)
  74. else:
  75. print('Please run this way:\n')
  76. print(' python demo_onnx.py <weight_file> <image_path> <batch_size> <n_classes> <IN_IMAGE_H> <IN_IMAGE_W>')