InferNetOnnxDetectBasic.cs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. using System;
  2. using System.Collections.Generic;
  3. using Microsoft.ML.OnnxRuntime;
  4. using Microsoft.ML.OnnxRuntime.Tensors;
  5. using System.Runtime.InteropServices;
  6. using System.IO;
  7. using System.Runtime.InteropServices;
  namespace YOLODetectProcessLib
  {
      /// <summary>
      /// Base class for YOLO detectors that run an ONNX model through ONNX Runtime and delegate box
      /// selection (confidence filtering, IoU/IoS filtering, box merging) to the native
      /// YOLOOutputPostProcessUtil.dll. Derived classes provide <see cref="NetworkName"/> and
      /// <see cref="HashCode"/> and may adjust the protected, non-readonly settings: _allboxesnum and
      /// _classesnum are read by <see cref="LoadNetwork"/> to size the output buffer, while the
      /// post-processing settings are read on every <see cref="DetectionPostProcess"/> call.
      /// </summary>
      /// <remarks>
      /// Each image's preprocess / inference / post-process sequence runs under one internal lock,
      /// because all images share the same input, output and result buffers.
      /// </remarks>
      public abstract class InferNetOnnxDetectBasic : IInferenceNetwork
      {
          #region dll import
          /// <summary>
          /// YOLO output variant understood by the native post-processing routine.
          /// </summary>
          public enum YoloType
          {
              yolov5,
              yolov8
          }

          /// <summary>
          /// Parameter block for the native box-selection routine in YOLOOutputPostProcessUtil.dll.
          /// It is marshalled sequentially with Pack = 1, so field order, field types and the one-byte
          /// bool marshalling must stay in sync with the native struct definition.
          /// The defaults quoted on the fields come from the original field notes; the values actually
          /// sent are copied from this class's protected fields, several of which differ.
          /// </summary>
          [StructLayout(LayoutKind.Sequential, Pack = 1)]
          public struct YoloDetectInput
          {
              public YoloType YoloType;           // post-processing variant; yolov5 and yolov8 are supported
              public IntPtr Alldetectresult;      // input: all raw results, (batch_size, allboxesnum, 5 + classesnum), float32
              public IntPtr Outputresult;         // output: (batch_size, maxdet, 6), float32
              public IntPtr Errormsg;             // error message buffer
              public float Confthres;             // box confidence threshold, documented default 0.6
              // Confthres is the box confidence; Clsconfthres was added as a separate class-confidence
              // threshold (previously the two were the same).
              public float Clsconfthres;          // class confidence threshold, documented default 0.5
              public int Batchsize;               // batch size, documented default 1
              public int Allboxesnum;             // total number of input boxes
              public int Classesnum;              // number of classes, excluding background
              public int Maxdet;                  // maximum number of output results, documented default 10
              public int Errormsgmaxlen;          // maximum error message length (longer messages are not copied)
              public int Modelinputheight;        // model input height, used to drop small boxes, documented default 320
              public int Modelinputwidth;         // model input width, used to drop small boxes, documented default 320
              public float Minboxratio;           // boxes whose area / image area is below this are removed, documented default 0.01
              // Used by ApplyPostProcessToBBox to filter boxes within a single class.
              public int Postprocesstopk;         // top-k kept by ApplyPostProcessToBBox, documented default 5
              public float Ioufltth;              // IoU filter threshold, documented default 0.3
              public float Iosfltth;              // IoS filter threshold, documented default 0.3
              // FindBoxesToUnion: when two boxes overlap strongly and merging them adds little area,
              // the two boxes are merged.
              public float Unioniouth;            // union parameter, documented default 1.0
              public float Unioniosth;            // union parameter, documented default 1.0
              public float Unionuobth;            // union parameter, documented default 1.0
              // ApplyBoxClassFilter: filters boxes of different classes on the same image.
              public float Ioufltthdiffcls;       // ApplyBoxClassFilter IoU filter threshold, documented default 0.3
              public float Iosfltthdiffcls;       // ApplyBoxClassFilter IoS filter threshold, documented default 0.3
              [MarshalAs(UnmanagedType.I1)]
              public bool Enableioufilt;          // ApplyPostProcessToBBox: enable IoU filtering (drop the lower-scoring box when IoU > Ioufltth), documented default true
              [MarshalAs(UnmanagedType.I1)]
              public bool Enableiosfilt;          // ApplyPostProcessToBBox: enable IoS filtering (drop the lower-scoring box when IoS > Iosfltth), documented default false
              [MarshalAs(UnmanagedType.I1)]
              public bool Enableunion;            // FindBoxesToUnion: enable box merging, documented default false
              [MarshalAs(UnmanagedType.I1)]
              public bool Enableioufiltdiffcls;   // ApplyBoxClassFilter: enable IoU filtering, documented default true
              [MarshalAs(UnmanagedType.I1)]
              public bool Enableiosfiltdiffcls;   // ApplyBoxClassFilter: enable IoS filtering, documented default false
          }

          /// <summary>
          /// Native box selection: reads <see cref="YoloDetectInput.Alldetectresult"/> and writes the
          /// selected boxes into <see cref="YoloDetectInput.Outputresult"/>.
          /// </summary>
          /// <returns><c>true</c> on success; on failure an error message may be written to Errormsg.</returns>
          [DllImport(@"YOLOOutputPostProcessUtil.dll", CallingConvention = CallingConvention.Cdecl)]
          [return: MarshalAs(UnmanagedType.I1)]
          public static extern bool SelectNeededBoxes(ref YoloDetectInput input);
          #endregion

          #region protected
          private SessionOptions _sessionOption = null;
          protected volatile bool _modelLoaded;
          private InferenceSession _inferSession;
          // Guards the session and every shared buffer used by one inference (input, output, results).
          private readonly object _sessLocker = new object();
          protected int[] _inputTensorShape = null;
          protected int[] _outputTensorShape = null;
          private FixedBufferOnnxValue _inputValue;
          private FixedBufferOnnxValue _outputValue;
          private readonly int _modelInputH = 320;
          private readonly int _modelInputW = 320;
          private readonly int _modelInputC = 3;
          private readonly string[] _inputTensorNames = new[] { "images" };
          private readonly string[] _outputTensorNames = new[] { "output" };
          protected readonly EnumResizeMode _resizeMode = EnumResizeMode.Warp;
          protected readonly EnumMeanValueType _meanValueType = EnumMeanValueType.None;
          protected readonly float _meanValueR = 0;
          protected readonly float _meanValueG = 0;
          protected readonly float _meanValueB = 0;
          protected readonly EnumScaleValueType _scaleValueType = EnumScaleValueType.ConstantScale;
          protected readonly float _scaleValueR = 255;
          protected readonly float _scaleValueG = 255;
          protected readonly float _scaleValueB = 255;
          protected readonly EnumNormalizationType _normType = EnumNormalizationType.None;
          protected readonly EnumChannelOrder _channelOrder = EnumChannelOrder.RGB;
          protected readonly EnumAxisOrder _axisOrder = EnumAxisOrder.CHW;
          // Raw network output for one image: _allboxesnum * (_classesnum + 5) floats, written by Run.
          private float[] _detectedResultData;
          protected MoldedImage _moldedImage = null;
          // Native error-message buffer of the most recent DetectionPostProcess call.
          protected char[] _errormsg;
          protected float _confthres = 0.6f;
          protected float _clsconfthres = 0.3f;
          protected int _batchsize = 1;
          protected int _allboxesnum = 8400;
          protected int _classesnum = 7;
          protected int _maxdet = 5;
          protected int _errormsgmaxlen = 256;
          protected int _modelinputheight = 320;
          protected int _modelinputwidth = 320;
          protected float _minboxratio = 0.001f;
          protected int _postprocesstopk = 20;
          protected bool _enableioufilt = true;
          protected bool _enableiosfilt = true;
          protected float _ioufltth = 0.3f;
          protected float _iosfltth = 0.3f;
          // Native selection result of the most recent call: _batchsize * _maxdet rows of
          // (label, conf, left, top, right, bottom).
          protected float[] _outputresultData;
          protected bool _enableunion = false;
          protected float _unioniouth = 1.0f;
          protected float _unioniosth = 1.0f;
          protected float _unionuobth = 1.0f;
          protected bool _enableioufiltdiffcls = true;
          protected bool _enableiosfiltdiffcls = true;
          protected float _ioufltthdiffcls = 0.01f;
          protected float _iosfltthdiffcls = 0.01f;
          protected YoloType _yolotype = YoloType.yolov5;
          #endregion

          /// <summary>
          /// Hash value of the model, passed to InferenceNetworkUtils.ReadNetworkDataFromFile when loading.
          /// </summary>
          public abstract string HashCode { get; }

          /// <summary>
          /// Network name, passed to InferenceNetworkUtils.ReadNetworkDataFromFile when loading.
          /// </summary>
          public abstract string NetworkName { get; }

          /// <summary>
          /// Whether the model is currently loaded (false again after disposal).
          /// </summary>
          public bool NetworkLoaded => _modelLoaded;

          /// <summary>
          /// Notifies subscribers that an error occurred during inference or post-processing.
          /// </summary>
          public event EventHandler<ErrorEventArgs> NotifyError;

          /// <summary>
          /// Loads the ONNX model, validates its input/output tensors and allocates the fixed
          /// input/output buffers. Does nothing if a model is already loaded.
          /// </summary>
          /// <param name="numCPU">Thread count used for both inter-op and intra-op parallelism.</param>
          /// <param name="netDir">Directory passed to InferenceNetworkUtils.ReadNetworkDataFromFile to locate the model.</param>
          /// <remarks>
          /// If any step after reading the model data fails, everything created so far is released
          /// before the exception is rethrown, so a failed load does not leak the session or buffers.
          /// </remarks>
          public virtual void LoadNetwork(int numCPU, string netDir)
          {
              // Do not load twice.
              if (_modelLoaded)
              {
                  return;
              }

              byte[] trainedNetwork = InferenceNetworkUtils.ReadNetworkDataFromFile(netDir, NetworkName, HashCode);
              try
              {
                  _sessionOption = new SessionOptions();
                  _sessionOption.InterOpNumThreads = numCPU;
                  _sessionOption.IntraOpNumThreads = numCPU;
                  lock (_sessLocker)
                  {
                      _inferSession = new InferenceSession(trainedNetwork, _sessionOption);
                  }

                  // Verify the model's input/output tensor names and shapes.
                  _inputTensorShape = new int[] { 1, _modelInputC, _modelInputH, _modelInputW };
                  _outputTensorShape = new int[] { 1, _allboxesnum, _classesnum + 5 };
                  var inputTensorDimensions = new List<int[]> { _inputTensorShape };
                  var outputTensorDimensions = new List<int[]> { _outputTensorShape };
                  InferenceNetworkUtils.CheckOnnxModel(_inferSession, _inputTensorNames, _outputTensorNames, inputTensorDimensions, outputTensorDimensions);

                  // Preprocessor; its DataBuffer backs the model input tensor.
                  _moldedImage = new MoldedImage(_modelInputH, _modelInputW, _modelInputC, _resizeMode, _meanValueType,
                      _meanValueR, _meanValueG, _meanValueB, _scaleValueType, _scaleValueR, _scaleValueG, _scaleValueB,
                      _normType, _channelOrder, _axisOrder);

                  // Run reads the input tensor from DataBuffer and writes the output into
                  // _detectedResultData, which post-processing then reads without copying.
                  _detectedResultData = new float[_allboxesnum * (_classesnum + 5)];
                  var tensorInput = new DenseTensor<float>(_moldedImage.DataBuffer, _inputTensorShape);
                  var tensorOutput = new DenseTensor<float>(_detectedResultData, _outputTensorShape);
                  _inputValue = FixedBufferOnnxValue.CreateFromTensor(tensorInput);
                  _outputValue = FixedBufferOnnxValue.CreateFromTensor(tensorOutput);
              }
              catch
              {
                  // Release the partially created session/options/buffers so a retry starts clean.
                  DoDispose();
                  throw;
              }

              _modelLoaded = true;
          }

          /// <summary>
          /// Runs preprocessing, inference and post-processing on each image in turn.
          /// </summary>
          /// <param name="images">Images to process; detections are mapped back using each image's ROI.</param>
          /// <returns>
          /// One detection array per input image, or <c>null</c> if anything throws (including a
          /// null argument or an unloaded network); the exception is reported through
          /// <see cref="NotifyError"/>.
          /// </returns>
          public virtual IDetectedObject[][] Process(InferenceNetworkInputImage[] images)
          {
              try
              {
                  if (images == null)
                  {
                      throw new ArgumentNullException(nameof(images));
                  }
                  if (!_modelLoaded)
                  {
                      throw new InvalidOperationException("The network must be loaded before calling Process.");
                  }

                  var results = new IDetectedObject[images.Length][];
                  for (int ni = 0; ni < images.Length; ni++)
                  {
                      InferenceNetworkInputImage image = images[ni];

                      // The whole per-image pipeline shares the input buffer, the output buffer and
                      // the post-processing buffers, so it must run under one lock, not just Run.
                      lock (_sessLocker)
                      {
                          _moldedImage.Process(image);
                          var inputValues = new[] { _inputValue };
                          var outputValues = new[] { _outputValue };
                          _inferSession.Run(_inputTensorNames, inputValues, _outputTensorNames, outputValues);

                          results[ni] = DetectionPostProcess(
                              (int)(image.ROI.Right - image.ROI.Left),
                              (int)(image.ROI.Bottom - image.ROI.Top),
                              (int)image.ROI.Left,
                              (int)image.ROI.Top);
                      }
                  }
                  return results;
              }
              catch (Exception ex)
              {
                  NotifyError?.Invoke(this, new ErrorEventArgs(ex));
                  return null;
              }
          }

          /// <summary>
          /// Runs the native box selection on the most recent network output and maps the selected
          /// boxes from model-input coordinates back into the source image.
          /// </summary>
          /// <param name="RoiWidth">Width of the ROI that was fed to the network.</param>
          /// <param name="RoiHeight">Height of the ROI that was fed to the network.</param>
          /// <param name="RoiLeft">X offset of the ROI in the source image.</param>
          /// <param name="RoiTop">Y offset of the ROI in the source image.</param>
          /// <returns>
          /// Detections in source-image coordinates, clamped to the ROI. An empty array is returned
          /// (and the reason reported through <see cref="NotifyError"/>) when the network is not
          /// loaded, the batch size is not 1, or the native call fails.
          /// </returns>
          /// <remarks>
          /// Uses shared buffers; <see cref="Process"/> calls it under the inference lock, and external
          /// callers must not run it concurrently with inference.
          /// </remarks>
          public IDetectedObject[] DetectionPostProcess(int RoiWidth, int RoiHeight, int RoiLeft, int RoiTop)
          {
              // Without a loaded network there is no buffer to pin, and the native code would receive
              // a null input pointer.
              if (_detectedResultData == null)
              {
                  NotifyError?.Invoke(this, new ErrorEventArgs(new InvalidOperationException("network is not loaded")));
                  return new IDetectedObject[0];
              }

              // Only one image per call is supported. _detectedResultData holds a single
              // (allboxesnum x (classesnum + 5)) slab, while Alldetectresult is laid out as
              // (batch_size, allboxesnum, 5 + classesnum). A larger batch size would let the native
              // code read past the end of the buffer, so stop before calling it.
              if (_batchsize != 1)
              {
                  NotifyError?.Invoke(this, new ErrorEventArgs(new InvalidOperationException("batchsize must be 1")));
                  return new IDetectedObject[0];
              }

              // Fresh, zero-filled buffers on every call: the decoding step treats all-zero rows as empty.
              _outputresultData = new float[_batchsize * _maxdet * 6];
              _errormsg = new char[_errormsgmaxlen];

              bool succeeded;
              string nativeError = null;
              GCHandle detectHandle = default(GCHandle);
              GCHandle outputHandle = default(GCHandle);
              GCHandle errorHandle = default(GCHandle);
              try
              {
                  detectHandle = GCHandle.Alloc(_detectedResultData, GCHandleType.Pinned);
                  outputHandle = GCHandle.Alloc(_outputresultData, GCHandleType.Pinned);
                  errorHandle = GCHandle.Alloc(_errormsg, GCHandleType.Pinned);

                  YoloDetectInput yoloinput = CreateYoloInput(
                      detectHandle.AddrOfPinnedObject(),
                      outputHandle.AddrOfPinnedObject(),
                      errorHandle.AddrOfPinnedObject());
                  succeeded = SelectNeededBoxes(ref yoloinput);

                  if (!succeeded)
                  {
                      // NOTE(review): assumes the native side writes a NUL-terminated narrow (char*)
                      // string of at most Errormsgmaxlen bytes. The buffer is zero-filled and
                      // 2 * _errormsgmaxlen bytes long (a C# char is 2 bytes), so decoding stays in bounds.
                      nativeError = Marshal.PtrToStringAnsi(errorHandle.AddrOfPinnedObject());
                  }
              }
              finally
              {
                  // Pinned handles must always be released, otherwise the arrays stay pinned and
                  // uncollectable and the handles accumulate on every call.
                  if (errorHandle.IsAllocated)
                  {
                      errorHandle.Free();
                  }
                  if (outputHandle.IsAllocated)
                  {
                      outputHandle.Free();
                  }
                  if (detectHandle.IsAllocated)
                  {
                      detectHandle.Free();
                  }
              }

              if (!succeeded)
              {
                  string message = "Failed at calling SelectNeededBoxes";
                  if (!string.IsNullOrEmpty(nativeError))
                  {
                      message += ": " + nativeError;
                  }
                  NotifyError?.Invoke(this, new ErrorEventArgs(new InvalidOperationException(message)));
                  return new IDetectedObject[0];
              }

              return ConvertToDetections(_outputresultData, RoiWidth, RoiHeight, RoiLeft, RoiTop);
          }

          /// <summary>
          /// Releases the session, session options, tensor buffers and preprocessor.
          /// </summary>
          public virtual void Dispose()
          {
              DoDispose();
              GC.SuppressFinalize(this);
          }

          /// <summary>
          /// Finalizer: releases the same resources when <see cref="Dispose"/> was never called.
          /// </summary>
          ~InferNetOnnxDetectBasic()
          {
              DoDispose();
          }

          #region private
          /// <summary>
          /// Builds the native parameter block from the current protected settings and the given
          /// pinned buffer addresses.
          /// </summary>
          private YoloDetectInput CreateYoloInput(IntPtr detectData, IntPtr outputData, IntPtr errorData)
          {
              return new YoloDetectInput
              {
                  YoloType = _yolotype,
                  Alldetectresult = detectData,
                  Outputresult = outputData,
                  Errormsg = errorData,
                  Confthres = _confthres,
                  Clsconfthres = _clsconfthres,
                  Batchsize = _batchsize,
                  Allboxesnum = _allboxesnum,
                  Classesnum = _classesnum,
                  Maxdet = _maxdet,
                  Errormsgmaxlen = _errormsgmaxlen,
                  Modelinputheight = _modelinputheight,
                  Modelinputwidth = _modelinputwidth,
                  Minboxratio = _minboxratio,
                  Postprocesstopk = _postprocesstopk,
                  Enableioufilt = _enableioufilt,
                  Enableiosfilt = _enableiosfilt,
                  Ioufltth = _ioufltth,
                  Iosfltth = _iosfltth,
                  Enableunion = _enableunion,
                  Unioniouth = _unioniouth,
                  Unioniosth = _unioniosth,
                  Unionuobth = _unionuobth,
                  Enableioufiltdiffcls = _enableioufiltdiffcls,
                  Enableiosfiltdiffcls = _enableiosfiltdiffcls,
                  Ioufltthdiffcls = _ioufltthdiffcls,
                  Iosfltthdiffcls = _iosfltthdiffcls,
              };
          }

          /// <summary>
          /// Decodes the native output rows (label, conf, left, top, right, bottom) into detections in
          /// source-image coordinates, clamped to the ROI.
          /// </summary>
          private IDetectedObject[] ConvertToDetections(float[] outs, int roiWidth, int roiHeight, int roiLeft, int roiTop)
          {
              var results = new List<IDetectedObject>();

              // Scale from model-input pixels to ROI pixels (identical for every row).
              float scaleW = (float)roiWidth / _modelInputW;
              float scaleH = (float)roiHeight / _modelInputH;

              for (int ni = 0; ni < _maxdet; ni++)
              {
                  int row = ni * 6;
                  int label = (int)outs[row];
                  float conf = outs[row + 1];
                  // NOTE(review): the native coordinates appear to be normalised by
                  // Modelinputwidth/height; multiplying by them gives model-input pixels.
                  float left = outs[row + 2] * _modelinputwidth;
                  float top = outs[row + 3] * _modelinputheight;
                  float right = outs[row + 4] * _modelinputwidth;
                  float bottom = outs[row + 5] * _modelinputheight;

                  // Rows with zero confidence or a zero right edge are treated as empty slots.
                  if (conf == 0.0f || (int)right == 0)
                  {
                      continue;
                  }

                  // Map into the ROI and keep the box inside it.
                  float leftRoi = ClampToRange(left * scaleW, roiWidth);
                  float rightRoi = ClampToRange(right * scaleW, roiWidth);
                  float topRoi = ClampToRange(top * scaleH, roiHeight);
                  float bottomRoi = ClampToRange(bottom * scaleH, roiHeight);

                  // Offset into the source image coordinate system.
                  float leftOrig = leftRoi + roiLeft;
                  float rightOrig = rightRoi + roiLeft;
                  float topOrig = topRoi + roiTop;
                  float bottomOrig = bottomRoi + roiTop;

                  Rect rect = new Rect((int)leftOrig, (int)topOrig, (int)rightOrig - (int)leftOrig, (int)bottomOrig - (int)topOrig);
                  DetectedObject detectedObject = new DetectedObject(label, conf, rect);
                  results.Add(detectedObject);
              }

              return results.ToArray();
          }

          /// <summary>
          /// Clamps <paramref name="value"/> to the range [0, <paramref name="max"/>].
          /// </summary>
          private static float ClampToRange(float value, float max)
          {
              return Math.Min(Math.Max(value, 0), max);
          }

          /// <summary>
          /// Disposes every owned resource under the inference lock (so it cannot race a running
          /// inference) and marks the network as not loaded. Safe to call more than once.
          /// </summary>
          private void DoDispose()
          {
              lock (_sessLocker)
              {
                  _modelLoaded = false;
                  _inferSession?.Dispose();
                  _inferSession = null;
                  _sessionOption?.Dispose();
                  _sessionOption = null;
                  _inputValue?.Dispose();
                  _inputValue = null;
                  _outputValue?.Dispose();
                  _outputValue = null;
                  _moldedImage?.Dispose();
                  _moldedImage = null;
              }
          }
          #endregion
      }
  }