123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- using System;
- using System.IO;
- using System.Collections.Generic;
- using Microsoft.ML.OnnxRuntime;
namespace YOLODetectProcessLib
{
    /// <summary>
    /// Helper functions shared by the inference networks.
    /// </summary>
    public class InferenceNetworkUtils
    {
        #region public
        /// <summary>
        /// Reads an encrypted network parameter file, verifies its hash and returns the decrypted contents.
        /// </summary>
        /// <param name="netsDir">
        /// Directory that holds the network files. If it does not exist (including null/empty),
        /// <paramref name="netFileName"/> is used as the file path by itself.
        /// </param>
        /// <param name="netFileName">
        /// Name of the network file inside <paramref name="netsDir"/>, or a path when that directory does not exist.
        /// </param>
        /// <param name="netHashCode">
        /// Expected value of <c>HashCode.ComputeHashCode</c> over the encrypted file bytes; compared with
        /// ordinal string equality.
        /// </param>
        /// <returns>The file bytes after <c>AES.AESDecrypt</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="netFileName"/> or <paramref name="netHashCode"/> is null.</exception>
        /// <exception cref="FileNotFoundException">The resolved network file does not exist.</exception>
        /// <exception cref="ArgumentException">The file's hash differs from <paramref name="netHashCode"/>.</exception>
        public static byte[] ReadNetworkDataFromFile(string netsDir, string netFileName, string netHashCode)
        {
            if (netFileName == null)
            {
                throw new ArgumentNullException(nameof(netFileName));
            }
            if (netHashCode == null)
            {
                throw new ArgumentNullException(nameof(netHashCode));
            }

            // When the directory is missing, treat netFileName as a path on its own.
            string netFilePath = Directory.Exists(netsDir) ? Path.Combine(netsDir, netFileName) : netFileName;
            if (!File.Exists(netFilePath))
            {
                throw new FileNotFoundException(
                    "Failed to load network:" + netFileName + " from:" + netFilePath + ".", netFilePath);
            }

            // The file is only read, never written, so no locking is needed.
            byte[] fileDataEncrypted = File.ReadAllBytes(netFilePath);

            // Reject model files whose hash does not match the one this software version expects.
            string hashstr = HashCode.ComputeHashCode(fileDataEncrypted);
            if (!string.Equals(hashstr, netHashCode, StringComparison.Ordinal))
            {
                // ArgumentException's constructor order is (message, paramName).
                throw new ArgumentException("Unexpected parameter data file(" + netFilePath +
                    ") for the current version of AIDiagSystem", nameof(netFileName));
            }

            return AES.AESDecrypt(fileDataEncrypted);
        }

        /// <summary>
        /// Checks that an ONNX model's input/output tensor names, element types and shapes match what the caller expects.
        /// </summary>
        /// <param name="inferSession">Session whose model metadata is checked.</param>
        /// <param name="inputTensorNames">Expected input tensor names; the model must have exactly this many inputs.</param>
        /// <param name="outputTensorNames">Expected output tensor names; the model may expose additional outputs.</param>
        /// <param name="inputTensorDimensions">Expected shape of each tensor in <paramref name="inputTensorNames"/>, in the same order.</param>
        /// <param name="outputTensorDimensions">Expected shape of each tensor in <paramref name="outputTensorNames"/>, in the same order.</param>
        /// <remarks>
        /// Every checked tensor must have element type <see cref="float"/>. A model axis reported as -1 matches any
        /// expected size. For tensors of rank 0 or 1 only the rank is compared, not the axis sizes.
        /// </remarks>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">The model input count differs, or an expected tensor name is missing.</exception>
        /// <exception cref="ArgumentException">
        /// A tensor is not float, has the wrong rank or axis size, or a dimensions list does not match its names array.
        /// </exception>
        public static void CheckOnnxModel(InferenceSession inferSession, string[] inputTensorNames, string[] outputTensorNames,
            List<int[]> inputTensorDimensions, List<int[]> outputTensorDimensions)
        {
            if (inferSession == null)
            {
                throw new ArgumentNullException(nameof(inferSession));
            }
            if (inputTensorNames == null)
            {
                throw new ArgumentNullException(nameof(inputTensorNames));
            }
            if (outputTensorNames == null)
            {
                throw new ArgumentNullException(nameof(outputTensorNames));
            }
            if (inputTensorDimensions == null)
            {
                throw new ArgumentNullException(nameof(inputTensorDimensions));
            }
            if (outputTensorDimensions == null)
            {
                throw new ArgumentNullException(nameof(outputTensorDimensions));
            }

            // Tensor names and element types.
            if (inferSession.InputMetadata.Count != inputTensorNames.Length)
            {
                throw new ArgumentOutOfRangeException(nameof(inputTensorNames), "The expected model input number is " +
                    inputTensorNames.Length + ", but got " + inferSession.InputMetadata.Count + ".");
            }
            CheckTensorNamesAndTypes(inferSession.InputMetadata, inputTensorNames, "input", nameof(inputTensorNames));

            // The number of model outputs is deliberately not checked: only the requested outputs must exist,
            // so models that expose additional outputs are accepted.
            CheckTensorNamesAndTypes(inferSession.OutputMetadata, outputTensorNames, "output", nameof(outputTensorNames));

            // Tensor shapes (names are known to exist at this point).
            CheckTensorDimensions(inferSession.InputMetadata, inputTensorNames, inputTensorDimensions, "input",
                nameof(inputTensorNames), nameof(inputTensorDimensions));
            CheckTensorDimensions(inferSession.OutputMetadata, outputTensorNames, outputTensorDimensions, "output",
                nameof(outputTensorNames), nameof(outputTensorDimensions));
        }
        #endregion

        #region private
        /// <summary>
        /// Verifies that every name in <paramref name="expectedNames"/> exists in <paramref name="metadata"/>
        /// and that its element type is float.
        /// </summary>
        /// <param name="metadata">Input or output metadata of an inference session.</param>
        /// <param name="expectedNames">Tensor names the caller requires.</param>
        /// <param name="direction">"input" or "output"; used only in error messages.</param>
        /// <param name="paramName">Caller-facing parameter name reported by thrown exceptions.</param>
        private static void CheckTensorNamesAndTypes(IReadOnlyDictionary<string, NodeMetadata> metadata,
            string[] expectedNames, string direction, string paramName)
        {
            foreach (string tensorName in expectedNames)
            {
                NodeMetadata node;
                if (!metadata.TryGetValue(tensorName, out node))
                {
                    throw new ArgumentOutOfRangeException(paramName,
                        "Find no " + direction + " tensor with expected name of " + tensorName + ".");
                }
                if (node.ElementType != typeof(float))
                {
                    throw new ArgumentException("The expected " + direction + " tensor type is float, but got " +
                        node.ElementType + ".", paramName);
                }
            }
        }

        /// <summary>
        /// Compares the model's tensor shapes with the expected shapes, pairing
        /// <paramref name="tensorNames"/>[i] with <paramref name="expectedDimensions"/>[i].
        /// Every name must already be known to exist in <paramref name="metadata"/>.
        /// </summary>
        /// <param name="metadata">Input or output metadata of an inference session.</param>
        /// <param name="tensorNames">Tensor names to check.</param>
        /// <param name="expectedDimensions">Expected shape per tensor name.</param>
        /// <param name="direction">"input" or "output"; used only in error messages.</param>
        /// <param name="namesParamName">Caller-facing name of the names parameter (for messages).</param>
        /// <param name="dimensionsParamName">Caller-facing name of the dimensions parameter (reported by exceptions).</param>
        private static void CheckTensorDimensions(IReadOnlyDictionary<string, NodeMetadata> metadata,
            string[] tensorNames, List<int[]> expectedDimensions, string direction,
            string namesParamName, string dimensionsParamName)
        {
            if (expectedDimensions.Count != tensorNames.Length)
            {
                throw new ArgumentException("The " + dimensionsParamName + " and " + namesParamName +
                    " should have the same length.", dimensionsParamName);
            }
            for (int ni = 0; ni < tensorNames.Length; ni++)
            {
                int[] actual = metadata[tensorNames[ni]].Dimensions;
                int[] expected = expectedDimensions[ni];
                if (expected == null)
                {
                    throw new ArgumentException("The expected " + direction + " tensor dimension for " +
                        tensorNames[ni] + " is null.", dimensionsParamName);
                }
                if (actual.Length != expected.Length)
                {
                    throw new ArgumentException("The expected " + direction + " tensor dimension length is " +
                        expected.Length + ", but got " + actual.Length + ".", dimensionsParamName);
                }
                // Rank 0/1 tensors: only the rank is compared.
                // NOTE(review): presumably their length is model-dependent; confirm this exemption is intended.
                if (actual.Length <= 1)
                {
                    continue;
                }
                for (int nj = 0; nj < actual.Length; nj++)
                {
                    // An axis reported as -1 by the model (dynamic size) accepts any expected value.
                    if (actual[nj] != expected[nj] && actual[nj] != -1)
                    {
                        throw new ArgumentException("The expected " + direction + " tensor dimension is " +
                            expected[nj] + ", but got " + actual[nj] + " for the " + nj + "th axis.",
                            dimensionsParamName);
                    }
                }
            }
        }
        #endregion
    }
}
|