12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- # -*- coding: utf-8 -*-
- # @Time : 2022/10/12 10:55
- # @Author : Marvin.yuan
- # @File : export.py
- # @Description : Export the model to .onnx file.
- from pathlib import Path
- import onnx
- import torch.onnx
def export_onnx(model,
                ckpt,
                inp,
                exp,
                input_names=None,
                output_names=None,
                opset=12,
                simplify=True,
                verbose=True):
    """
    Load a checkpoint into a torch model and export it to an ONNX file.

    Args:
        model: Torch model. Its device is taken from its first parameter, so it
            must have at least one parameter.
        ckpt: Checkpoint path (``str`` or ``pathlib.Path``) ending in ``.pth``.
            The loaded object must be a mapping holding an ``"ema"`` or a
            ``"model"`` state dict; a non-empty ``"ema"`` entry is preferred.
        inp: Example input tensor used to trace the model.
        exp: Export path; its suffix is replaced with ``.onnx``.
        input_names: Names to assign to the input nodes of the graph, in order.
            Default: ['input_1'].
        output_names: Names to assign to the output nodes of the graph, in order.
            Default: ['output_1'].
        opset: Opset version. Default: 12.
        simplify: Flag to simplify the onnx model with onnxsim. Simplification
            is best-effort: on any failure a message is printed and the
            unsimplified model is kept. Default: True.
        verbose: Forwarded to ``torch.onnx.export``. Default: True.
    Returns:
        model_onnx: The ONNX model that was written to disk.
    Raises:
        ValueError: If ``ckpt`` does not end with ``.pth``.
    """
    # Build fresh lists per call instead of sharing mutable default arguments.
    input_names = ['input_1'] if input_names is None else list(input_names)
    output_names = ['output_1'] if output_names is None else list(output_names)

    # str() so pathlib.Path checkpoints work too; raise (not assert) so the
    # check survives `python -O`.
    ckpt_path = str(ckpt)
    if not ckpt_path.endswith(".pth"):
        raise ValueError(f"Passed ckpt_path: {ckpt_path} is not a .pth file!")

    device = next(model.parameters()).device
    # Global torch setting: flush denormal floats to zero on CPU.
    torch.set_flush_denormal(True)
    model = model.to(device)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(ckpt_path, map_location=device)
    # Prefer EMA weights when present and non-empty, otherwise the raw weights.
    csd = state.get("ema") or state["model"]
    model.load_state_dict(csd)
    model.eval()

    f = Path(exp).with_suffix(".onnx")
    torch.onnx.export(
        model,
        inp.to(device),
        str(f),
        input_names=input_names,
        output_names=output_names,
        opset_version=opset,
        verbose=verbose)

    # Checks
    model_onnx = onnx.load(str(f))  # load onnx model
    onnx.checker.check_model(model_onnx)  # check

    if simplify:
        try:
            import onnxsim
            simplified, check = onnxsim.simplify(model_onnx)
            if not check:
                raise RuntimeError("onnxsim validation check failed")
            onnx.save(simplified, str(f))
        except Exception as e:  # best-effort: keep the unsimplified model
            print(f"ONNX simplification skipped: {e}")
        else:
            # Adopt the simplified graph only after it was validated and saved,
            # so the returned model always matches the file on disk.
            model_onnx = simplified
    return model_onnx
if __name__ == "__main__":
    from utils.config import Config

    # Build the network described by the YAML config and route forward()
    # through its export-specific path so tracing uses that graph.
    net = Config(path='model.yaml').model
    net.forward = net.export

    # Checkpoint to load and the .onnx file written next to it.
    weights = "best.pth"
    onnx_path = Path(weights).with_suffix(".onnx")

    export_onnx(
        model=net,
        ckpt=weights,
        inp=torch.randn(1, 3, 256, 256),  # dummy NCHW-shaped example input
        exp=onnx_path,
        input_names=['input_1'],
        output_names=['output_1'],
        opset=12,
        simplify=True,
    )
|