using System;
using System.IO;
using System.Collections.Generic;
using Microsoft.ML.OnnxRuntime;
namespace YOLODetectProcessLib
{
///
/// 推理网络共用的一些辅助的功能函数
///
public class InferenceNetworkUtils
{
    #region public
    /// <summary>
    /// Reads an encrypted network weight file, verifies its hash against the value expected by
    /// the current version, and returns the decrypted contents.
    /// </summary>
    /// <param name="netsDir">
    /// Directory containing the network file. If this directory does not exist,
    /// <paramref name="netFileName"/> is used as the file path on its own.
    /// </param>
    /// <param name="netFileName">Network file name (or a path, when <paramref name="netsDir"/> does not exist).</param>
    /// <param name="netHashCode">Expected hash of the encrypted file, compared ordinally with the result of HashCode.ComputeHashCode.</param>
    /// <returns>The file contents after decryption with AES.AESDecrypt.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="netFileName"/> or <paramref name="netHashCode"/> is null.</exception>
    /// <exception cref="FileNotFoundException">The resolved network file does not exist.</exception>
    /// <exception cref="ArgumentException">The hash of the file does not match <paramref name="netHashCode"/>.</exception>
    public static byte[] ReadNetworkDataFromFile(string netsDir, string netFileName, string netHashCode)
    {
        if (netFileName == null)
        {
            throw new ArgumentNullException("netFileName");
        }
        if (netHashCode == null)
        {
            throw new ArgumentNullException("netHashCode");
        }

        // Resolve the weight file path: inside netsDir when that directory exists, otherwise netFileName as given.
        string netFilePath = Directory.Exists(netsDir) ? Path.Combine(netsDir, netFileName) : netFileName;
        if (!File.Exists(netFilePath))
        {
            throw new FileNotFoundException("Failed to load network:" + netFileName + " from:" + netFilePath + ".", netFilePath);
        }

        // Read the file (read-only access, so no locking is needed).
        byte[] fileDataEncrypted = File.ReadAllBytes(netFilePath);

        // Reject weight files whose hash differs from the one required by the current version.
        string hashstr = HashCode.ComputeHashCode(fileDataEncrypted);
        if (!string.Equals(hashstr, netHashCode, StringComparison.Ordinal))
        {
            // Note: the ArgumentException constructor takes (message, paramName).
            throw new ArgumentException("Unexpected parameter data file(" + netFilePath +
                ") for the current version of AIDiagSystem", "netFilePath");
        }

        // Decrypt and return.
        return AES.AESDecrypt(fileDataEncrypted);
    }

    /// <summary>
    /// Checks that an ONNX model's input/output tensor names, element types and dimensions match
    /// what the caller expects. Throws on the first mismatch found.
    /// </summary>
    /// <param name="inferSession">The loaded inference session whose metadata is checked.</param>
    /// <param name="inputTensorNames">Expected input tensor names; the model must have exactly this many inputs.</param>
    /// <param name="outputTensorNames">Expected output tensor names; the model may have additional outputs.</param>
    /// <param name="inputTensorDimensions">Expected dimensions of each input, in the same order as <paramref name="inputTensorNames"/>.</param>
    /// <param name="outputTensorDimensions">Expected dimensions of each output, in the same order as <paramref name="outputTensorNames"/>.</param>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The input count differs, or an expected tensor name is missing.</exception>
    /// <exception cref="ArgumentException">A tensor type is not float, a dimension list has the wrong length, or a dimension mismatches.</exception>
    public static void CheckOnnxModel(InferenceSession inferSession, string[] inputTensorNames, string[] outputTensorNames,
        List<int[]> inputTensorDimensions, List<int[]> outputTensorDimensions)
    {
        if (inferSession == null)
        {
            throw new ArgumentNullException("inferSession");
        }
        if (inputTensorNames == null)
        {
            throw new ArgumentNullException("inputTensorNames");
        }
        if (outputTensorNames == null)
        {
            throw new ArgumentNullException("outputTensorNames");
        }
        if (inputTensorDimensions == null)
        {
            throw new ArgumentNullException("inputTensorDimensions");
        }
        if (outputTensorDimensions == null)
        {
            throw new ArgumentNullException("outputTensorDimensions");
        }

        // Inputs: the model must expose exactly the expected inputs, each with float elements.
        if (inferSession.InputMetadata.Count != inputTensorNames.Length)
        {
            throw new ArgumentOutOfRangeException("input", "The expected model input number is " + inputTensorNames.Length +
                ", but got " + inferSession.InputMetadata.Count + ".");
        }
        foreach (string inputTensorName in inputTensorNames)
        {
            if (!inferSession.InputMetadata.ContainsKey(inputTensorName))
            {
                throw new ArgumentOutOfRangeException("inputTensorName", "Find no input tensor with expected name of " + inputTensorName + ".");
            }
            CheckTensorElementTypeIsFloat(inferSession.InputMetadata[inputTensorName].ElementType, "input");
        }

        // Outputs: the output count is deliberately not compared, so models exposing extra outputs
        // are accepted; every expected output must exist and have float elements.
        foreach (string outputTensorName in outputTensorNames)
        {
            if (!inferSession.OutputMetadata.ContainsKey(outputTensorName))
            {
                throw new ArgumentOutOfRangeException("outputTensorName", "Find no output tensor with expected name of " + outputTensorName + ".");
            }
            CheckTensorElementTypeIsFloat(inferSession.OutputMetadata[outputTensorName].ElementType, "output");
        }

        // Input dimensions.
        if (inputTensorDimensions.Count != inputTensorNames.Length)
        {
            throw new ArgumentException("The inputTensorDimensions and inputTensorNames should have the same length.", "inputTensorDimensions");
        }
        for (int ni = 0; ni < inputTensorNames.Length; ni++)
        {
            CheckTensorDimensions(inferSession.InputMetadata[inputTensorNames[ni]].Dimensions, inputTensorDimensions[ni], "input");
        }

        // Output dimensions.
        if (outputTensorDimensions.Count != outputTensorNames.Length)
        {
            throw new ArgumentException("The outputTensorDimensions and outputTensorNames should have the same length.", "outputTensorDimensions");
        }
        for (int ni = 0; ni < outputTensorNames.Length; ni++)
        {
            CheckTensorDimensions(inferSession.OutputMetadata[outputTensorNames[ni]].Dimensions, outputTensorDimensions[ni], "output");
        }
    }
    #endregion

    #region private
    /// <summary>
    /// Throws if a tensor's element type is not <see cref="float"/>.
    /// </summary>
    /// <param name="elementType">Element type reported by the model metadata.</param>
    /// <param name="direction">"input" or "output"; used in the message and the parameter name.</param>
    private static void CheckTensorElementTypeIsFloat(Type elementType, string direction)
    {
        if (elementType != typeof(float))
        {
            throw new ArgumentException("The expected " + direction + " tensor type is float, but got " +
                elementType.ToString() + ".", direction + "TensorType");
        }
    }

    /// <summary>
    /// Compares a tensor's model-reported dimensions with the expected ones.
    /// </summary>
    /// <param name="actualDims">Dimensions reported by the model metadata.</param>
    /// <param name="expectedDims">Dimensions expected by the caller.</param>
    /// <param name="direction">"input" or "output"; used in messages and parameter names.</param>
    private static void CheckTensorDimensions(int[] actualDims, int[] expectedDims, string direction)
    {
        if (expectedDims == null)
        {
            throw new ArgumentException("The expected " + direction + " tensor dimensions must not be null.", direction + "TensorDimensions");
        }
        if (actualDims.Length != expectedDims.Length)
        {
            throw new ArgumentException("The expected " + direction + " tensor dimension length is " +
                expectedDims.Length + ", but got " + actualDims.Length + ".", direction + "TensorDimensionLength");
        }
        // Tensors of rank 0 or 1 are only checked for rank; their sizes are not compared.
        if (actualDims.Length <= 1)
        {
            return;
        }
        for (int axis = 0; axis < actualDims.Length; axis++)
        {
            // A model-side size of -1 (how ONNX Runtime reports dynamic/symbolic axes) matches any expected size.
            if (actualDims[axis] != expectedDims[axis] && actualDims[axis] != -1)
            {
                throw new ArgumentException("The expected " + direction + " tensor dimension is " +
                    expectedDims[axis] + ", but got " + actualDims[axis] + " for the " + axis + "th axis.", direction + "TensorDimension");
            }
        }
    }
    #endregion
}
}