InferenceNetworkUtils.cs 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. using System;
  2. using System.IO;
  3. using System.Collections.Generic;
  4. using Microsoft.ML.OnnxRuntime;
  namespace YOLODetectProcessLib
  {
      /// <summary>
      /// Helper functions shared by the inference networks.
      /// </summary>
      public class InferenceNetworkUtils
      {
          #region public
          /// <summary>
          /// Reads a network parameter file from disk, checks its hash against the expected value and returns the
          /// decrypted file contents.
          /// </summary>
          /// <param name="netsDir">Directory that holds the network files. When this directory does not exist,
          /// <paramref name="netFileName"/> is used as the file path by itself.</param>
          /// <param name="netFileName">Name of the network file, combined with <paramref name="netsDir"/> when that
          /// directory exists.</param>
          /// <param name="netHashCode">Expected hash string of the file as stored on disk (before decryption). It is
          /// compared ordinally (case-sensitive) with the value returned by <c>HashCode.ComputeHashCode</c>.</param>
          /// <returns>The file contents after decryption with <c>AES.AESDecrypt</c>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="netFileName"/> is null.</exception>
          /// <exception cref="FileNotFoundException">No file exists at the resolved path.</exception>
          /// <exception cref="ArgumentException">The file hash does not equal <paramref name="netHashCode"/>.</exception>
          public static byte[] ReadNetworkDataFromFile(string netsDir, string netFileName, string netHashCode)
          {
              if (netFileName == null)
              {
                  throw new ArgumentNullException("netFileName");
              }

              // Resolve the weight file path: inside netsDir if that directory exists, otherwise netFileName as given.
              string netFilePath = Directory.Exists(netsDir) ? Path.Combine(netsDir, netFileName) : netFileName;
              if (!File.Exists(netFilePath))
              {
                  throw new FileNotFoundException(
                      "Failed to load network:" + netFileName + " from:" + netFilePath + ".", netFilePath);
              }

              // Read the file (read-only access, so no locking is needed).
              byte[] fileDataEncrypted = File.ReadAllBytes(netFilePath);

              // The hash is computed over the raw on-disk (still encrypted) bytes and must match the value
              // required by the current version.
              string hashstr = HashCode.ComputeHashCode(fileDataEncrypted);
              if (!string.Equals(hashstr, netHashCode, StringComparison.Ordinal))
              {
                  // ArgumentException's constructor takes (message, paramName) in that order.
                  throw new ArgumentException("Unexpected parameter data file(" + netFilePath +
                      ") for the current version of AIDiagSystem", "netFileName");
              }

              // Decrypt and return the file contents.
              return AES.AESDecrypt(fileDataEncrypted);
          }

          /// <summary>
          /// Checks that the input/output tensors of an ONNX model have the expected names, element type and
          /// dimensions.
          /// </summary>
          /// <param name="inferSession">Session whose <c>InputMetadata</c> and <c>OutputMetadata</c> are inspected.</param>
          /// <param name="inputTensorNames">Expected input tensor names; the model must have exactly this many
          /// inputs.</param>
          /// <param name="outputTensorNames">Output tensor names that must be present; the number of model outputs is
          /// not checked.</param>
          /// <param name="inputTensorDimensions">Expected dimensions of each input, in the same order as
          /// <paramref name="inputTensorNames"/>.</param>
          /// <param name="outputTensorDimensions">Expected dimensions of each output, in the same order as
          /// <paramref name="outputTensorNames"/>.</param>
          /// <remarks>
          /// Every listed tensor must have element type <see cref="float"/>. The rank of each tensor must equal the
          /// length of its expected dimension array. For tensors of rank greater than 1, every axis must also match,
          /// except axes the model reports as -1, which are accepted without comparison. Tensors of rank 0 or 1 only
          /// have their rank checked.
          /// </remarks>
          /// <exception cref="ArgumentNullException">Any argument is null.</exception>
          /// <exception cref="ArgumentOutOfRangeException">The model input count differs from
          /// <paramref name="inputTensorNames"/>, or a listed tensor name is missing from the model.</exception>
          /// <exception cref="ArgumentException">A tensor is not float, a dimension list has the wrong length or a
          /// null entry, or a rank or axis size does not match.</exception>
          public static void CheckOnnxModel(InferenceSession inferSession, string[] inputTensorNames, string[] outputTensorNames,
              List<int[]> inputTensorDimensions, List<int[]> outputTensorDimensions)
          {
              if (inferSession == null)
              {
                  throw new ArgumentNullException("inferSession");
              }
              if (inputTensorNames == null)
              {
                  throw new ArgumentNullException("inputTensorNames");
              }
              if (outputTensorNames == null)
              {
                  throw new ArgumentNullException("outputTensorNames");
              }
              if (inputTensorDimensions == null)
              {
                  throw new ArgumentNullException("inputTensorDimensions");
              }
              if (outputTensorDimensions == null)
              {
                  throw new ArgumentNullException("outputTensorDimensions");
              }

              // The model must have exactly the expected number of inputs.
              if (inferSession.InputMetadata.Count != inputTensorNames.Length)
              {
                  throw new ArgumentOutOfRangeException("input", "The expected model input number is " + inputTensorNames.Length +
                      ", but got " + inferSession.InputMetadata.Count + ".");
              }
              // NOTE: the number of model outputs is intentionally not compared, so models exposing additional
              // outputs are accepted; only the listed outputs are validated below.

              // Names and element types first (inputs, then outputs), then dimensions (inputs, then outputs).
              CheckTensorNamesAndTypes(inferSession, inputTensorNames, true);
              CheckTensorNamesAndTypes(inferSession, outputTensorNames, false);
              CheckTensorDimensions(inferSession, inputTensorNames, inputTensorDimensions, true);
              CheckTensorDimensions(inferSession, outputTensorNames, outputTensorDimensions, false);
          }
          #endregion

          #region private
          /// <summary>
          /// Checks that every name in <paramref name="tensorNames"/> exists in the session's input metadata
          /// (<paramref name="isInput"/> true) or output metadata (false) and that its element type is float.
          /// </summary>
          private static void CheckTensorNamesAndTypes(InferenceSession inferSession, string[] tensorNames, bool isInput)
          {
              var metadata = isInput ? inferSession.InputMetadata : inferSession.OutputMetadata;
              string kind = isInput ? "input" : "output";
              foreach (string tensorName in tensorNames)
              {
                  if (!metadata.ContainsKey(tensorName))
                  {
                      throw new ArgumentOutOfRangeException(kind + "TensorName",
                          "Find no " + kind + " tensor with expected name of " + tensorName + ".");
                  }
                  var nodeMetadata = metadata[tensorName];
                  if (nodeMetadata.ElementType != typeof(float))
                  {
                      throw new ArgumentException("The expected " + kind + " tensor type is float, but got " +
                          nodeMetadata.ElementType + ".", kind + "TensorType");
                  }
              }
          }

          /// <summary>
          /// Checks that the rank and (for rank greater than 1) every axis size of the named input/output tensors
          /// match <paramref name="expectedDimensions"/>, accepting axes the model reports as -1.
          /// Assumes every name has already been validated by <see cref="CheckTensorNamesAndTypes"/>.
          /// </summary>
          private static void CheckTensorDimensions(InferenceSession inferSession, string[] tensorNames,
              List<int[]> expectedDimensions, bool isInput)
          {
              var metadata = isInput ? inferSession.InputMetadata : inferSession.OutputMetadata;
              string kind = isInput ? "input" : "output";
              if (expectedDimensions.Count != tensorNames.Length)
              {
                  throw new ArgumentException("The " + kind + "TensorDimensions and " + kind +
                      "TensorNames should have the same length.", kind + "TensorDimensions");
              }
              for (int ni = 0; ni < tensorNames.Length; ni++)
              {
                  var actualDimension = metadata[tensorNames[ni]].Dimensions;
                  int[] expectedDimension = expectedDimensions[ni];
                  if (expectedDimension == null)
                  {
                      throw new ArgumentException("The expected dimensions of " + kind + " tensor " + tensorNames[ni] +
                          " must not be null.", kind + "TensorDimensions");
                  }
                  if (actualDimension.Length != expectedDimension.Length)
                  {
                      throw new ArgumentException("The expected " + kind + " tensor dimension length is " +
                          expectedDimension.Length + ", but got " + actualDimension.Length + ".", kind + "TensorDimensionLength");
                  }
                  // Rank-0/1 tensors only have their rank checked; their axis sizes are not compared.
                  if (actualDimension.Length <= 1)
                  {
                      continue;
                  }
                  for (int nj = 0; nj < actualDimension.Length; nj++)
                  {
                      // An axis the model reports as -1 (presumably a dynamic/symbolic axis) matches any expected size.
                      if (actualDimension[nj] != expectedDimension[nj] && actualDimension[nj] != -1)
                      {
                          throw new ArgumentException("The expected " + kind + " tensor dimension is " +
                              expectedDimension[nj] + ", but got " + actualDimension[nj] + " for the " + nj + "th axis.",
                              kind + "TensorDimension");
                      }
                  }
              }
          }
          #endregion
      }
  }