export.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2022/10/12 10:55
  3. # @Author : Marvin.yuan
  4. # @File : export.py
  5. # @Description : Export the model to .onnx file.
  6. from pathlib import Path
  7. import onnx
  8. import torch.onnx
  9. def export_onnx(model,
  10. ckpt,
  11. inp,
  12. exp,
  13. input_names=['input_1'],
  14. output_names=['output_1'],
  15. opset=12,
  16. simplify=True):
  17. """
  18. VINNO ONNX export
  19. Args:
  20. model: Torch model.
  21. ckpt: Checkpoint path.
  22. inp: Input tensor.
  23. exp: Export path.
  24. input_names: Names to assign to the input nodes of the graph, in order. Default: ['input_1'].
  25. output_names: Names to assign to the output nodes of the graph, in order. Default: ['output_1'].
  26. opset: Opset version. Default: 12.
  27. simplify: Flag to simplify onnx model. Default: True.
  28. Returns:
  29. model_onnx: Onnx model.
  30. """
  31. assert ckpt.endswith(".pth"), f"Passed ckpt_path: {ckpt} is not pt file!"
  32. device = next(model.parameters()).device
  33. torch.set_flush_denormal(True)
  34. model = model.to(device)
  35. ckpt = torch.load(ckpt, map_location=device)
  36. csd = ckpt.get("ema") or ckpt["model"]
  37. model.load_state_dict(csd)
  38. model.eval()
  39. f = Path(exp).with_suffix(".onnx")
  40. torch.onnx.export(
  41. model.to(device),
  42. inp.to(device),
  43. str(f),
  44. input_names=input_names,
  45. output_names=output_names,
  46. opset_version=opset,
  47. verbose=True)
  48. # Checks
  49. model_onnx = onnx.load(str(f)) # load onnx model
  50. onnx.checker.check_model(model_onnx) # check
  51. if simplify:
  52. try:
  53. import onnxsim
  54. model_onnx, check = onnxsim.simplify(model_onnx)
  55. assert check, 'assert check failed'
  56. onnx.save(model_onnx, str(f))
  57. except Exception as e:
  58. print(e)
  59. return model_onnx
  60. if __name__ == "__main__":
  61. from utils.config import Config
  62. cfg = Config(path='model.yaml')
  63. model = cfg.model
  64. model.forward = model.export
  65. ckpt_path = "best.pth"
  66. dummy_input = torch.randn(1, 3, 256, 256)
  67. export_onnx(
  68. model=model,
  69. ckpt=ckpt_path,
  70. inp=dummy_input,
  71. exp=Path(ckpt_path).with_suffix(".onnx"),
  72. input_names=['input_1'],
  73. output_names=['output_1'],
  74. opset=12,
  75. simplify=True)