export.py
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit

Format                      | `export.py --include` | Model
---                         | ---                   | ---
PyTorch                     | -                     | yolov5s.pt
TorchScript                 | `torchscript`         | yolov5s.torchscript
ONNX                        | `onnx`                | yolov5s.onnx
OpenVINO                    | `openvino`            | yolov5s_openvino_model/
TensorRT                    | `engine`              | yolov5s.engine
CoreML                      | `coreml`              | yolov5s.mlmodel
TensorFlow SavedModel       | `saved_model`         | yolov5s_saved_model/
TensorFlow GraphDef         | `pb`                  | yolov5s.pb
TensorFlow Lite             | `tflite`              | yolov5s.tflite
TensorFlow Edge TPU         | `edgetpu`             | yolov5s_edgetpu.tflite
TensorFlow.js               | `tfjs`                | yolov5s_web_model/
PaddlePaddle                | `paddle`              | yolov5s_paddle_model/

Requirements:
    $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu  # CPU
    $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow  # GPU

Usage:
    $ python export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ...

Inference:
    $ python detect.py --weights yolov5s.pt                 # PyTorch
                                 yolov5s.torchscript        # TorchScript
                                 yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
                                 yolov5s_openvino_model     # OpenVINO
                                 yolov5s.engine             # TensorRT
                                 yolov5s.mlmodel            # CoreML (macOS-only)
                                 yolov5s_saved_model        # TensorFlow SavedModel
                                 yolov5s.pb                 # TensorFlow GraphDef
                                 yolov5s.tflite             # TensorFlow Lite
                                 yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
                                 yolov5s_paddle_model       # PaddlePaddle

TensorFlow.js:
    $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
    $ npm install
    $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
    $ npm start
"""

import argparse
import contextlib
import json
import os
import platform
import re
import subprocess
import sys
import time
import warnings
from pathlib import Path

import pandas as pd
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != "Windows":
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.experimental import attempt_load
from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
from utils.dataloaders import LoadImages
from utils.general import (
    LOGGER,
    Profile,
    check_dataset,
    check_img_size,
    check_requirements,
    check_version,
    check_yaml,
    colorstr,
    file_size,
    get_default_args,
    print_args,
    url2file,
    yaml_save,
)
from utils.torch_utils import select_device, smart_inference_mode

MACOS = platform.system() == "Darwin"  # macOS environment

class iOSModel(torch.nn.Module):
    def __init__(self, model, im):
        """Initializes an iOS compatible model with normalization based on image dimensions."""
        super().__init__()
        b, c, h, w = im.shape  # batch, channel, height, width
        self.model = model
        self.nc = model.nc  # number of classes
        if w == h:
            self.normalize = 1.0 / w
        else:
            self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h])  # broadcast (slower, smaller)
            # np = model(im)[0].shape[1]  # number of points
            # self.normalize = torch.tensor([1. / w, 1. / h, 1. / w, 1. / h]).expand(np, 4)  # explicit (faster, larger)

    def forward(self, x):
        """Runs forward pass on the input tensor, returning class confidences and normalized coordinates."""
        xywh, conf, cls = self.model(x)[0].squeeze().split((4, 1, self.nc), 1)
        return cls * conf, xywh * self.normalize  # confidence (3780, 80), coordinates (3780, 4)

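
# Note: iOSModel rescales the xywh outputs so box coordinates are relative to the image size (see the
# "coordinates" output description in pipeline_coreml below); square inputs use a single scalar, other
# shapes a broadcast 4-element tensor.
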
def export_formats():
    """Returns a DataFrame of supported YOLOv5 model export formats and their properties."""
    x = [
        ["PyTorch", "-", ".pt", True, True],
        ["TorchScript", "torchscript", ".torchscript", True, True],
        ["ONNX", "onnx", ".onnx", True, True],
        ["OpenVINO", "openvino", "_openvino_model", True, False],
        ["TensorRT", "engine", ".engine", False, True],
        ["CoreML", "coreml", ".mlmodel", True, False],
        ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
        ["TensorFlow GraphDef", "pb", ".pb", True, True],
        ["TensorFlow Lite", "tflite", ".tflite", True, False],
        ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", False, False],
        ["TensorFlow.js", "tfjs", "_web_model", False, False],
        ["PaddlePaddle", "paddle", "_paddle_model", True, True],
    ]
    return pd.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"])

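
# Example (sketch): filtering the formats table, e.g. listing the --include arguments that support GPU inference
# fmts = export_formats()
# print(fmts[fmts["GPU"]]["Argument"].tolist())
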
def try_export(inner_func):
    """Decorator @try_export for YOLOv5 model export functions that logs success/failure, time taken, and file size."""
    inner_args = get_default_args(inner_func)

    def outer_func(*args, **kwargs):
        prefix = inner_args["prefix"]
        try:
            with Profile() as dt:
                f, model = inner_func(*args, **kwargs)
            LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)")
            return f, model
        except Exception as e:
            LOGGER.info(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}")
            return None, None

    return outer_func

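
# Note: functions decorated with @try_export must declare a `prefix` keyword argument with a default value,
# since outer_func reads it from the decorated function's defaults (inner_args["prefix"]) for logging.
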
@try_export
def export_torchscript(model, im, file, optimize, prefix=colorstr("TorchScript:")):
    """Exports YOLOv5 model to TorchScript format, optionally optimized for mobile, with image shape and stride
    metadata.
    """
    LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
    f = file.with_suffix(".torchscript")

    ts = torch.jit.trace(model, im, strict=False)
    d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
    extra_files = {"config.txt": json.dumps(d)}  # torch._C.ExtraFilesMap()
    if optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
        optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
    else:
        ts.save(str(f), _extra_files=extra_files)
    return f, None

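
# Example (sketch): reloading the exported TorchScript model and its embedded metadata (non --optimize save path)
# extra_files = {"config.txt": ""}  # filled in-place by torch.jit.load
# ts_model = torch.jit.load("yolov5s.torchscript", _extra_files=extra_files)
# meta = json.loads(extra_files["config.txt"])  # {'shape': ..., 'stride': ..., 'names': ...}
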
@try_export
def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr("ONNX:")):
    """Exports a YOLOv5 model to ONNX format with dynamic axes and optional simplification."""
    # check_requirements("onnx>=1.12.0")
    import onnx

    LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__}...")
    f = str(file.with_suffix(".onnx"))

    output_names = ["output0", "output1"] if isinstance(model, SegmentationModel) else ["output0"]
    if dynamic:
        dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
        if isinstance(model, SegmentationModel):
            dynamic["output0"] = {0: "batch", 1: "anchors"}  # shape(1,25200,85)
            dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
        elif isinstance(model, DetectionModel):
            dynamic["output0"] = {0: "batch", 1: "anchors"}  # shape(1,25200,85)

    torch.onnx.export(
        model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
        im.cpu() if dynamic else im,
        f,
        verbose=False,
        opset_version=opset,
        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
        input_names=["images"],
        output_names=output_names,
        dynamic_axes=dynamic or None,
    )

    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Metadata
    d = {"stride": int(max(model.stride)), "names": model.names}
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)
    onnx.save(model_onnx, f)

    # Simplify
    if simplify:
        try:
            cuda = torch.cuda.is_available()
            # check_requirements(("onnxruntime-gpu" if cuda else "onnxruntime", "onnx-simplifier>=0.4.1"))
            import onnxsim

            LOGGER.info(f"{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...")
            model_onnx, check = onnxsim.simplify(model_onnx)
            assert check, "assert check failed"
            onnx.save(model_onnx, f)
        except Exception as e:
            LOGGER.info(f"{prefix} simplifier failure: {e}")
    return f, model_onnx

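
# Example (sketch): running the exported ONNX model with onnxruntime, assuming a 640x640 static export
# import numpy as np
# import onnxruntime as ort
# session = ort.InferenceSession("yolov5s.onnx", providers=["CPUExecutionProvider"])
# preds = session.run(None, {"images": np.zeros((1, 3, 640, 640), dtype=np.float32)})  # list of output arrays
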
@try_export
def export_openvino(file, metadata, half, int8, data, prefix=colorstr("OpenVINO:")):
    # YOLOv5 OpenVINO export
    # check_requirements("openvino-dev>=2023.0")  # requires openvino-dev: https://pypi.org/project/openvino-dev/
    import openvino.runtime as ov  # noqa
    from openvino.tools import mo  # noqa

    LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
    f = str(file).replace(file.suffix, f"_{'int8_' if int8 else ''}openvino_model{os.sep}")
    f_onnx = file.with_suffix(".onnx")
    f_ov = str(Path(f) / file.with_suffix(".xml").name)

    ov_model = mo.convert_model(f_onnx, model_name=file.stem, framework="onnx", compress_to_fp16=half)  # export

    if int8:
        # check_requirements("nncf>=2.5.0")  # requires at least version 2.5.0 to use the post-training quantization
        import nncf
        import numpy as np

        from utils.dataloaders import create_dataloader

        def gen_dataloader(yaml_path, task="train", imgsz=640, workers=4):
            data_yaml = check_yaml(yaml_path)
            data = check_dataset(data_yaml)
            dataloader = create_dataloader(
                data[task], imgsz=imgsz, batch_size=1, stride=32, pad=0.5, single_cls=False, rect=False, workers=workers
            )[0]
            return dataloader

        # noqa: F811

        def transform_fn(data_item):
            """
            Quantization transform function.

            Extracts and preprocesses input data from a dataloader item for quantization.

            Parameters:
                data_item: Tuple with data item produced by DataLoader during iteration

            Returns:
                input_tensor: Input data for quantization
            """
            assert data_item[0].dtype == torch.uint8, "input image must be uint8 for the quantization preprocessing"

            img = data_item[0].numpy().astype(np.float32)  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            return np.expand_dims(img, 0) if img.ndim == 3 else img

        ds = gen_dataloader(data)
        quantization_dataset = nncf.Dataset(ds, transform_fn)
        ov_model = nncf.quantize(ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED)

    ov.serialize(ov_model, f_ov)  # save
    yaml_save(Path(f) / file.with_suffix(".yaml").name, metadata)  # add metadata.yaml
    return f, None

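
# Example (sketch): running the exported IR with the OpenVINO runtime (paths are illustrative)
# import numpy as np
# import openvino.runtime as ov
# compiled = ov.Core().compile_model("yolov5s_openvino_model/yolov5s.xml", "CPU")
# preds = compiled(np.zeros((1, 3, 640, 640), dtype=np.float32))  # dict keyed by output nodes
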
@try_export
def export_paddle(model, im, file, metadata, prefix=colorstr("PaddlePaddle:")):
    """Exports a YOLOv5 model to PaddlePaddle format using X2Paddle, saving to `save_dir` and adding a metadata.yaml
    file.
    """
    # check_requirements(("paddlepaddle", "x2paddle"))
    import x2paddle
    from x2paddle.convert import pytorch2paddle

    LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
    f = str(file).replace(".pt", f"_paddle_model{os.sep}")

    pytorch2paddle(module=model, save_dir=f, jit_type="trace", input_examples=[im])  # export
    yaml_save(Path(f) / file.with_suffix(".yaml").name, metadata)  # add metadata.yaml
    return f, None

@try_export
def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")):
    """Exports YOLOv5 model to CoreML format with optional NMS, INT8, and FP16 support; requires coremltools."""
    # check_requirements("coremltools")
    import coremltools as ct

    LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
    f = file.with_suffix(".mlmodel")

    if nms:
        model = iOSModel(model, im)
    ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
    ct_model = ct.convert(ts, inputs=[ct.ImageType("image", shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
    bits, mode = (8, "kmeans_lut") if int8 else (16, "linear") if half else (32, None)
    if bits < 32:
        if MACOS:  # quantization only supported on macOS
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warning
                ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
        else:
            print(f"{prefix} quantization only supported on macOS, skipping...")
    ct_model.save(f)
    return f, ct_model

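
# Example (sketch, macOS only): loading the exported *.mlmodel and predicting on a blank image
# import coremltools as ct
# from PIL import Image
# mlmodel = ct.models.MLModel("yolov5s.mlmodel")
# out = mlmodel.predict({"image": Image.new("RGB", (640, 640))})  # keys follow the converted model's outputs
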
@try_export
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
    """
    Exports a YOLOv5 model to TensorRT engine format, requiring GPU and TensorRT>=7.0.0.

    https://developer.nvidia.com/tensorrt
    """
    assert im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. `python export.py --device 0`"
    try:
        import tensorrt as trt
    except Exception:
        if platform.system() == "Linux":
            check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
        import tensorrt as trt

    if trt.__version__[0] == "7":  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
        grid = model.model[-1].anchor_grid
        model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
        model.model[-1].anchor_grid = grid
    else:  # TensorRT >= 8
        check_version(trt.__version__, "8.0.0", hard=True)  # require tensorrt>=8.0.0
        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
    onnx = file.with_suffix(".onnx")

    LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
    assert onnx.exists(), f"failed to export ONNX file: {onnx}"
    f = file.with_suffix(".engine")  # TensorRT engine file
    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE

    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = workspace * 1 << 30
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx)):
        raise RuntimeError(f"failed to load ONNX file: {onnx}")

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

    if dynamic:
        if im.shape[0] <= 1:
            LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
        profile = builder.create_optimization_profile()
        for inp in inputs:
            profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
        config.add_optimization_profile(profile)

    LOGGER.info(f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}")
    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(f, "wb") as t:
        t.write(engine.serialize())
    return f, None

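
# Example (sketch): deserializing the exported engine with the TensorRT runtime; full inference additionally
# requires device buffers and an execution context, as handled by the YOLOv5 DetectMultiBackend loader
# import tensorrt as trt
# with open("yolov5s.engine", "rb") as engine_file:
#     engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(engine_file.read())
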
@try_export
def export_saved_model(
    model,
    im,
    file,
    dynamic,
    tf_nms=False,
    agnostic_nms=False,
    topk_per_class=100,
    topk_all=100,
    iou_thres=0.45,
    conf_thres=0.25,
    keras=False,
    prefix=colorstr("TensorFlow SavedModel:"),
):
    # YOLOv5 TensorFlow SavedModel export
    try:
        import tensorflow as tf
    except Exception:
        check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}<=2.15.1")
        import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    from models.tf import TFModel

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    if tf.__version__ > "2.13.1":
        helper_url = "https://github.com/ultralytics/yolov5/issues/12489"
        LOGGER.info(
            f"WARNING ⚠️ using Tensorflow {tf.__version__} > 2.13.1 might cause issue when exporting the model to tflite {helper_url}"
        )  # handling issue https://github.com/ultralytics/yolov5/issues/12489
    f = str(file).replace(".pt", "_saved_model")
    batch_size, ch, *imgsz = list(im.shape)  # BCHW

    tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
    im = tf.zeros((batch_size, *imgsz, ch))  # BHWC order for TensorFlow
    _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
    inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
    outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
    keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
    keras_model.trainable = False
    keras_model.summary()
    if keras:
        keras_model.save(f, save_format="tf")
    else:
        spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
        m = tf.function(lambda x: keras_model(x))  # full model
        m = m.get_concrete_function(spec)
        frozen_func = convert_variables_to_constants_v2(m)
        tfm = tf.Module()
        tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec])
        tfm.__call__(im)
        tf.saved_model.save(
            tfm,
            f,
            options=tf.saved_model.SaveOptions(experimental_custom_gradients=False)
            if check_version(tf.__version__, "2.6")
            else tf.saved_model.SaveOptions(),
        )
    return f, keras_model

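
# Example (sketch): reloading the exported SavedModel; inputs are BHWC, matching the tf.zeros dry run above
# import tensorflow as tf
# saved = tf.saved_model.load("yolov5s_saved_model")
# preds = saved(tf.zeros((1, 640, 640, 3)))  # calls the saved tf.Module's __call__ tf.function
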
@try_export
def export_pb(keras_model, file, prefix=colorstr("TensorFlow GraphDef:")):
    """Exports YOLOv5 model to TensorFlow GraphDef *.pb format; see https://github.com/leimao/Frozen_Graph_TensorFlow for details."""
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    f = file.with_suffix(".pb")

    m = tf.function(lambda x: keras_model(x))  # full model
    m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
    frozen_func = convert_variables_to_constants_v2(m)
    frozen_func.graph.as_graph_def()
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
    return f, None

@try_export
def export_tflite(
    keras_model, im, file, int8, per_tensor, data, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")
):
    # YOLOv5 TensorFlow Lite export
    import tensorflow as tf

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    batch_size, ch, *imgsz = list(im.shape)  # BCHW
    f = str(file).replace(".pt", "-fp16.tflite")

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
    converter.target_spec.supported_types = [tf.float16]
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if int8:
        from models.tf import representative_dataset_gen

        dataset = LoadImages(check_dataset(check_yaml(data))["train"], img_size=imgsz, auto=False)
        converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.target_spec.supported_types = []
        converter.inference_input_type = tf.uint8  # or tf.int8
        converter.inference_output_type = tf.uint8  # or tf.int8
        converter.experimental_new_quantizer = True
        if per_tensor:
            converter._experimental_disable_per_channel = True
        f = str(file).replace(".pt", "-int8.tflite")
    if nms or agnostic_nms:
        converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)

    tflite_model = converter.convert()
    open(f, "wb").write(tflite_model)
    return f, None

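
# Example (sketch): running the exported *.tflite file with the TFLite interpreter (fp16 variant shown)
# import numpy as np
# import tensorflow as tf
# interpreter = tf.lite.Interpreter(model_path="yolov5s-fp16.tflite")
# interpreter.allocate_tensors()
# inp, out = interpreter.get_input_details()[0], interpreter.get_output_details()[0]
# interpreter.set_tensor(inp["index"], np.zeros((1, 640, 640, 3), dtype=np.float32))
# interpreter.invoke()
# preds = interpreter.get_tensor(out["index"])
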
@try_export
def export_edgetpu(file, prefix=colorstr("Edge TPU:")):
    """
    Exports a YOLOv5 model to Edge TPU compatible TFLite format; requires Linux and Edge TPU compiler.

    https://coral.ai/docs/edgetpu/models-intro/
    """
    cmd = "edgetpu_compiler --version"
    help_url = "https://coral.ai/docs/edgetpu/compiler/"
    assert platform.system() == "Linux", f"export only supported on Linux. See {help_url}"
    if subprocess.run(f"{cmd} > /dev/null 2>&1", shell=True).returncode != 0:
        LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
        sudo = subprocess.run("sudo --version >/dev/null", shell=True).returncode == 0  # sudo installed on system
        for c in (
            "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -",
            'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
            "sudo apt-get update",
            "sudo apt-get install edgetpu-compiler",
        ):
            subprocess.run(c if sudo else c.replace("sudo ", ""), shell=True, check=True)
    ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

    LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
    f = str(file).replace(".pt", "-int8_edgetpu.tflite")  # Edge TPU model
    f_tfl = str(file).replace(".pt", "-int8.tflite")  # TFLite model

    subprocess.run(
        [
            "edgetpu_compiler",
            "-s",
            "-d",
            "-k",
            "10",
            "--out_dir",
            str(file.parent),
            f_tfl,
        ],
        check=True,
    )
    return f, None

@try_export
def export_tfjs(file, int8, prefix=colorstr("TensorFlow.js:")):
    """Exports a YOLOv5 model to TensorFlow.js format, optionally with uint8 quantization."""
    check_requirements("tensorflowjs")
    import tensorflowjs as tfjs

    LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
    f = str(file).replace(".pt", "_web_model")  # js dir
    f_pb = file.with_suffix(".pb")  # *.pb path
    f_json = f"{f}/model.json"  # *.json path

    args = [
        "tensorflowjs_converter",
        "--input_format=tf_frozen_model",
        "--quantize_uint8" if int8 else "",
        "--output_node_names=Identity,Identity_1,Identity_2,Identity_3",
        str(f_pb),
        f,
    ]
    subprocess.run([arg for arg in args if arg], check=True)

    json = Path(f_json).read_text()
    with open(f_json, "w") as j:  # sort JSON Identity_* in ascending order
        subst = re.sub(
            r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
            r'"Identity.?.?": {"name": "Identity.?.?"}, '
            r'"Identity.?.?": {"name": "Identity.?.?"}, '
            r'"Identity.?.?": {"name": "Identity.?.?"}}}',
            r'{"outputs": {"Identity": {"name": "Identity"}, '
            r'"Identity_1": {"name": "Identity_1"}, '
            r'"Identity_2": {"name": "Identity_2"}, '
            r'"Identity_3": {"name": "Identity_3"}}}',
            json,
        )
        j.write(subst)
    return f, None

def add_tflite_metadata(file, metadata, num_outputs):
    """
    Adds TFLite metadata to a model file, supporting multiple outputs, as specified by TensorFlow guidelines.

    https://www.tensorflow.org/lite/models/convert/metadata
    """
    with contextlib.suppress(ImportError):
        # check_requirements('tflite_support')
        from tflite_support import flatbuffers
        from tflite_support import metadata as _metadata
        from tflite_support import metadata_schema_py_generated as _metadata_fb

        tmp_file = Path("/tmp/meta.txt")
        with open(tmp_file, "w") as meta_f:
            meta_f.write(str(metadata))

        model_meta = _metadata_fb.ModelMetadataT()
        label_file = _metadata_fb.AssociatedFileT()
        label_file.name = tmp_file.name
        model_meta.associatedFiles = [label_file]

        subgraph = _metadata_fb.SubGraphMetadataT()
        subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
        subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
        model_meta.subgraphMetadata = [subgraph]

        b = flatbuffers.Builder(0)
        b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
        metadata_buf = b.Output()

        populator = _metadata.MetadataPopulator.with_model_file(file)
        populator.load_metadata_buffer(metadata_buf)
        populator.load_associated_files([str(tmp_file)])
        populator.populate()
        tmp_file.unlink()

def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:")):
    """Converts a PyTorch YOLOv5 model to CoreML format with NMS, handling different input/output shapes and saving the
    model.
    """
    import coremltools as ct
    from PIL import Image

    print(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
    batch_size, ch, h, w = list(im.shape)  # BCHW
    t = time.time()

    # YOLOv5 Output shapes
    spec = model.get_spec()
    out0, out1 = iter(spec.description.output)
    if platform.system() == "Darwin":
        img = Image.new("RGB", (w, h))  # img(192 width, 320 height)
        # img = torch.zeros((*opt.img_size, 3)).numpy()  # img size(320,192,3) iDetection
        out = model.predict({"image": img})
        out0_shape, out1_shape = out[out0.name].shape, out[out1.name].shape
    else:  # linux and windows can not run model.predict(), get sizes from pytorch output y
        s = tuple(y[0].shape)
        out0_shape, out1_shape = (s[1], s[2] - 5), (s[1], 4)  # (3780, 80), (3780, 4)

    # Checks
    nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
    na, nc = out0_shape
    # na, nc = out0.type.multiArrayType.shape  # number anchors, classes
    assert len(names) == nc, f"{len(names)} names found for nc={nc}"  # check

    # Define output shapes (missing)
    out0.type.multiArrayType.shape[:] = out0_shape  # (3780, 80)
    out1.type.multiArrayType.shape[:] = out1_shape  # (3780, 4)
    # spec.neuralNetwork.preprocessing[0].featureName = '0'

    # Flexible input shapes
    # from coremltools.models.neural_network import flexible_shape_utils
    # s = []  # shapes
    # s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
    # s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384))  # (height, width)
    # flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
    # r = flexible_shape_utils.NeuralNetworkImageSizeRange()  # shape ranges
    # r.add_height_range((192, 640))
    # r.add_width_range((192, 640))
    # flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)

    # Print
    print(spec.description)

    # Model from spec
    model = ct.models.MLModel(spec)

    # 3. Create NMS protobuf
    nms_spec = ct.proto.Model_pb2.Model()
    nms_spec.specificationVersion = 5
    for i in range(2):
        decoder_output = model._spec.description.output[i].SerializeToString()
        nms_spec.description.input.add()
        nms_spec.description.input[i].ParseFromString(decoder_output)
        nms_spec.description.output.add()
        nms_spec.description.output[i].ParseFromString(decoder_output)

    nms_spec.description.output[0].name = "confidence"
    nms_spec.description.output[1].name = "coordinates"

    output_sizes = [nc, 4]
    for i in range(2):
        ma_type = nms_spec.description.output[i].type.multiArrayType
        ma_type.shapeRange.sizeRanges.add()
        ma_type.shapeRange.sizeRanges[0].lowerBound = 0
        ma_type.shapeRange.sizeRanges[0].upperBound = -1
        ma_type.shapeRange.sizeRanges.add()
        ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
        ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
        del ma_type.shape[:]

    nms = nms_spec.nonMaximumSuppression
    nms.confidenceInputFeatureName = out0.name  # 1x507x80
    nms.coordinatesInputFeatureName = out1.name  # 1x507x4
    nms.confidenceOutputFeatureName = "confidence"
    nms.coordinatesOutputFeatureName = "coordinates"
    nms.iouThresholdInputFeatureName = "iouThreshold"
    nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
    nms.iouThreshold = 0.45
    nms.confidenceThreshold = 0.25
    nms.pickTop.perClass = True
    nms.stringClassLabels.vector.extend(names.values())
    nms_model = ct.models.MLModel(nms_spec)

    # 4. Pipeline models together
    pipeline = ct.models.pipeline.Pipeline(
        input_features=[
            ("image", ct.models.datatypes.Array(3, ny, nx)),
            ("iouThreshold", ct.models.datatypes.Double()),
            ("confidenceThreshold", ct.models.datatypes.Double()),
        ],
        output_features=["confidence", "coordinates"],
    )
    pipeline.add_model(model)
    pipeline.add_model(nms_model)

    # Correct datatypes
    pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
    pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
    pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())

    # Update metadata
    pipeline.spec.specificationVersion = 5
    pipeline.spec.description.metadata.versionString = "https://github.com/ultralytics/yolov5"
    pipeline.spec.description.metadata.shortDescription = "https://github.com/ultralytics/yolov5"
    pipeline.spec.description.metadata.author = "glenn.jocher@ultralytics.com"
    pipeline.spec.description.metadata.license = "https://github.com/ultralytics/yolov5/blob/master/LICENSE"
    pipeline.spec.description.metadata.userDefined.update(
        {
            "classes": ",".join(names.values()),
            "iou_threshold": str(nms.iouThreshold),
            "confidence_threshold": str(nms.confidenceThreshold),
        }
    )

    # Save the model
    f = file.with_suffix(".mlmodel")  # filename
    model = ct.models.MLModel(pipeline.spec)
    model.input_description["image"] = "Input image"
    model.input_description["iouThreshold"] = f"(optional) IOU Threshold override (default: {nms.iouThreshold})"
    model.input_description["confidenceThreshold"] = (
        f"(optional) Confidence Threshold override (default: {nms.confidenceThreshold})"
    )
    model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
    model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
    model.save(f)  # pipelined
    print(f"{prefix} pipeline success ({time.time() - t:.2f}s), saved as {f} ({file_size(f):.1f} MB)")

@smart_inference_mode()
def run(
    data=ROOT / "data/coco128.yaml",  # 'dataset.yaml path'
    weights=ROOT / "yolov5s.pt",  # weights path
    imgsz=(640, 640),  # image (height, width)
    batch_size=1,  # batch size
    device="cpu",  # cuda device, i.e. 0 or 0,1,2,3 or cpu
    include=("torchscript", "onnx"),  # include formats
    half=False,  # FP16 half-precision export
    inplace=False,  # set YOLOv5 Detect() inplace=True
    keras=False,  # use Keras
    optimize=False,  # TorchScript: optimize for mobile
    int8=False,  # CoreML/TF INT8 quantization
    per_tensor=False,  # TF per tensor quantization
    dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
    simplify=False,  # ONNX: simplify model
    opset=12,  # ONNX: opset version
    verbose=False,  # TensorRT: verbose log
    workspace=4,  # TensorRT: workspace size (GB)
    nms=False,  # TF: add NMS to model
    agnostic_nms=False,  # TF: add agnostic NMS to model
    topk_per_class=100,  # TF.js NMS: topk per class to keep
    topk_all=100,  # TF.js NMS: topk for all classes to keep
    iou_thres=0.45,  # TF.js NMS: IoU threshold
    conf_thres=0.25,  # TF.js NMS: confidence threshold
):
    t = time.time()
    include = [x.lower() for x in include]  # to lowercase
    fmts = tuple(export_formats()["Argument"][1:])  # --include arguments
    flags = [x in include for x in fmts]
    assert sum(flags) == len(include), f"ERROR: Invalid --include {include}, valid --include arguments are {fmts}"
    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags  # export booleans
    file = Path(url2file(weights) if str(weights).startswith(("http:/", "https:/")) else weights)  # PyTorch weights

    # Load PyTorch model
    device = select_device(device)
    if half:
        assert device.type != "cpu" or coreml, "--half only compatible with GPU export, i.e. use --device 0"
        assert not dynamic, "--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both"
    model = attempt_load(weights, device=device, inplace=True, fuse=True)  # load FP32 model

    # Checks
    imgsz *= 2 if len(imgsz) == 1 else 1  # expand
    if optimize:
        assert device.type == "cpu", "--optimize not compatible with cuda devices, i.e. use --device cpu"

    # Input
    gs = int(max(model.stride))  # grid size (max stride)
    imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
    im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection

    # Update model
    model.eval()
    for k, m in model.named_modules():
        if isinstance(m, Detect):
            m.inplace = inplace
            m.dynamic = dynamic
            m.export = True

    for _ in range(2):
        y = model(im)  # dry runs
    if half and not coreml:
        im, model = im.half(), model.half()  # to FP16
    shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
    metadata = {"stride": int(max(model.stride)), "names": model.names}  # model metadata
    LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

    # Exports
    f = [""] * len(fmts)  # exported filenames
    warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning)  # suppress TracerWarning
    if jit:  # TorchScript
        f[0], _ = export_torchscript(model, im, file, optimize)
    if engine:  # TensorRT required before ONNX
        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
    if onnx or xml:  # OpenVINO requires ONNX
        f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
    if xml:  # OpenVINO
        f[3], _ = export_openvino(file, metadata, half, int8, data)
    if coreml:  # CoreML
        f[4], ct_model = export_coreml(model, im, file, int8, half, nms)
        if nms:
            pipeline_coreml(ct_model, im, file, model.names, y)
    if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
        assert not tflite or not tfjs, "TFLite and TF.js models must be exported separately, please pass only one type."
        assert not isinstance(model, ClassificationModel), "ClassificationModel export to TF formats not yet supported."
        f[5], s_model = export_saved_model(
            model.cpu(),
            im,
            file,
            dynamic,
            tf_nms=nms or agnostic_nms or tfjs,
            agnostic_nms=agnostic_nms or tfjs,
            topk_per_class=topk_per_class,
            topk_all=topk_all,
            iou_thres=iou_thres,
            conf_thres=conf_thres,
            keras=keras,
        )
        if pb or tfjs:  # pb prerequisite to tfjs
            f[6], _ = export_pb(s_model, file)
        if tflite or edgetpu:
            f[7], _ = export_tflite(
                s_model, im, file, int8 or edgetpu, per_tensor, data=data, nms=nms, agnostic_nms=agnostic_nms
            )
            if edgetpu:
                f[8], _ = export_edgetpu(file)
            add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
        if tfjs:
            f[9], _ = export_tfjs(file, int8)
    if paddle:  # PaddlePaddle
        f[10], _ = export_paddle(model, im, file, metadata)

    # Finish
    f = [str(x) for x in f if x]  # filter out '' and None
    if any(f):
        cls, det, seg = (isinstance(model, x) for x in (ClassificationModel, DetectionModel, SegmentationModel))  # type
        det &= not seg  # segmentation models inherit from SegmentationModel(DetectionModel)
        dir = Path("segment" if seg else "classify" if cls else "")
        h = "--half" if half else ""  # --half FP16 inference arg
        s = (
            "# WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference"
            if cls
            else "# WARNING ⚠️ SegmentationModel not yet supported for PyTorch Hub AutoShape inference"
            if seg
            else ""
        )
        LOGGER.info(
            f"\nExport complete ({time.time() - t:.1f}s)"
            f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
            f"\nDetect:      python {dir / ('detect.py' if det else 'predict.py')} --weights {f[-1]} {h}"
            f"\nValidate:    python {dir / 'val.py'} --weights {f[-1]} {h}"
            f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')  {s}"
            f"\nVisualize:   https://netron.app"
        )
    return f  # return list of exported files/dirs

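
# Example (sketch): calling run() programmatically instead of through the CLI
# run(weights=ROOT / "yolov5s.pt", include=("onnx",), imgsz=(640, 640), device="cpu")
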
def parse_opt(known=False):
    """Parses command-line arguments for YOLOv5 model export configurations, returning the parsed options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default=ROOT / "data/Neck-Organ-Seg.yaml", help="dataset.yaml path")
    parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "runs/train/exp56/weights/best.pt", help="model.pt path(s)")
    parser.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int, default=[320, 320], help="image (h, w)")
    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
    parser.add_argument("--device", default="cpu", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
    parser.add_argument("--half", action="store_true", help="FP16 half-precision export")
    parser.add_argument("--inplace", action="store_true", help="set YOLOv5 Detect() inplace=True")
    parser.add_argument("--keras", action="store_true", help="TF: use Keras")
    parser.add_argument("--optimize", action="store_true", help="TorchScript: optimize for mobile")
    parser.add_argument("--int8", action="store_true", help="CoreML/TF/OpenVINO INT8 quantization")
    parser.add_argument("--per-tensor", action="store_true", help="TF per-tensor quantization")
    parser.add_argument("--dynamic", action="store_true", help="ONNX/TF/TensorRT: dynamic axes")
    parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model")
    parser.add_argument("--opset", type=int, default=17, help="ONNX: opset version")
    parser.add_argument("--verbose", action="store_true", help="TensorRT: verbose log")
    parser.add_argument("--workspace", type=int, default=4, help="TensorRT: workspace size (GB)")
    parser.add_argument("--nms", action="store_true", help="TF: add NMS to model")
    parser.add_argument("--agnostic-nms", action="store_true", help="TF: add agnostic NMS to model")
    parser.add_argument("--topk-per-class", type=int, default=100, help="TF.js NMS: topk per class to keep")
    parser.add_argument("--topk-all", type=int, default=100, help="TF.js NMS: topk for all classes to keep")
    parser.add_argument("--iou-thres", type=float, default=0.45, help="TF.js NMS: IoU threshold")
    parser.add_argument("--conf-thres", type=float, default=0.25, help="TF.js NMS: confidence threshold")
    parser.add_argument(
        "--include",
        nargs="+",
        default=["onnx"],
        help="torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle",
    )
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    print_args(vars(opt))
    return opt

def main(opt):
    """Executes YOLOv5 model export for each of the specified weights with the parsed command-line options."""
    for opt.weights in opt.weights if isinstance(opt.weights, list) else [opt.weights]:
        run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)