// InferNetOnnxPaddleOcrCrnnRegC.cpp — ONNX Runtime inference for the PaddleOCR CRNN text-recognition model
  1. #include"InferNetOnnxPaddleOcrCrnnRegC.h"
// Default constructor. Members rely on their in-class initializers; the
// model itself is loaded later via LoadNetwork().
InferNetOnnxPaddleOcrCrnnReg::InferNetOnnxPaddleOcrCrnnReg()
{
}
  5. void InferNetOnnxPaddleOcrCrnnReg::LoadNetwork(const void* modelDataRec, size_t modelDataLengthRec, const void* modelDataKeys, size_t modelDataLengthKeys)
  6. {
  7. if (_modelLoadedRec)
  8. {
  9. // 如果模型已加载,则释放之前的模型
  10. delete ort_session;
  11. ort_session = nullptr;
  12. }
  13. sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
  14. ort_session = new Session(env, modelDataRec, modelDataLengthRec, sessionOptions);
  15. size_t numInputNodes = ort_session->GetInputCount();
  16. size_t numOutputNodes = ort_session->GetOutputCount();
  17. AllocatorWithDefaultOptions allocator;
  18. for (int i = 0; i < numInputNodes; i++)
  19. {
  20. inputNames.push_back(ort_session->GetInputName(i, allocator));
  21. Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);
  22. auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
  23. auto input_dims = input_tensor_info.GetShape();
  24. inputNodeDims.push_back(input_dims);
  25. }
  26. for (int i = 0; i < numOutputNodes; i++)
  27. {
  28. outputNames.push_back(ort_session->GetOutputName(i, allocator));
  29. Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
  30. auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
  31. auto output_dims = output_tensor_info.GetShape();
  32. outputNodeDims.push_back(output_dims);
  33. }
  34. // 将字节数据转换为字符串
  35. std::string text(reinterpret_cast<const char*>(modelDataKeys), modelDataLengthKeys);
  36. // 使用字符串流处理字符串
  37. std::istringstream iss(text);
  38. std::string line;
  39. // 逐行读取并添加到 alphabet 中
  40. while (std::getline(iss, line))
  41. {
  42. this->alphabet.push_back(line);
  43. }
  44. this->alphabet.push_back(" ");
  45. names_len = this->alphabet.size();
  46. _modelLoadedRec = true;
  47. }
/// @brief Runs text recognition on one image: preprocess, normalize, build the
///        input tensor, execute the ONNX session, then CTC-decode to a string.
/// @param srcimgCv Source image (expected 3-channel 8-bit; see normalize_).
/// @return Recognized text.
std::string InferNetOnnxPaddleOcrCrnnReg::Process(cv::Mat& srcimgCv)
{
// Resize to fixed height, preserving aspect ratio (width capped at inpWidth).
cv::Mat dstimg = this->preprocess(srcimgCv);
this->normalize_(dstimg); // fills input_image_ (CHW float buffer in [-1, 1])
// Input tensor shape: NCHW = {1, 3, inpHeight, inpWidth}
array<int64_t, 4> input_shape_{ 1, 3, this->inpHeight, this->inpWidth };
// CPU memory info; the tensor borrows input_image_'s storage (no copy).
auto allocator_info = MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
// Create the input tensor
Value input_tensor_ = Value::CreateTensor<float>(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());
// Run inference: single input, all declared outputs.
std::vector<Value> ortOutputs = ort_session->Run(RunOptions{ nullptr }, &inputNames[0], &input_tensor_, 1, outputNames.data(), outputNames.size()); // run inference
// Raw pointer into the first output tensor's data
float* pdata = ortOutputs[0].GetTensorMutableData<float>();
// NOTE(review): assumes output shape [1, seq_len, num_classes] — w is the
// timestep count (dim 1) and h the per-step class count (dim 2); PostProcess
// indexes pdata[i * h + j] accordingly. Confirm against the exported model.
int h = ortOutputs.at(0).GetTensorTypeAndShapeInfo().GetShape().at(2);
int w = ortOutputs.at(0).GetTensorTypeAndShapeInfo().GetShape().at(1);
// One predicted label per timestep; filled inside PostProcess.
prebLabel.resize(w);
string results;
results = PostProcess(w, h, pdata);
return results;
}
  72. string InferNetOnnxPaddleOcrCrnnReg::PostProcess(int wIn, int hIn, float* pdataIn)
  73. {
  74. int i = 0, j = 0;
  75. // 遍历输出,获取每列的最大值的索引作为标签
  76. for (i = 0; i < wIn; i++)
  77. {
  78. int one_label_idx = 0;
  79. float max_data = -10000;
  80. for (j = 0; j < hIn; j++)
  81. {
  82. float data_ = pdataIn[i * hIn + j];
  83. if (data_ > max_data)
  84. {
  85. max_data = data_;
  86. one_label_idx = j;
  87. }
  88. }
  89. prebLabel[i] = one_label_idx;
  90. }
  91. // 存储去重后的非空白标签
  92. std::vector<int> no_repeat_blank_label;
  93. for (size_t elementIndex = 0; elementIndex < wIn; ++elementIndex)
  94. {
  95. if (prebLabel[elementIndex] != 0 && !(elementIndex > 0 && prebLabel[elementIndex - 1] == prebLabel[elementIndex]))
  96. {
  97. no_repeat_blank_label.push_back(prebLabel[elementIndex] - 1);
  98. }
  99. }
  100. // 构建最终的预测文本
  101. int len_s = no_repeat_blank_label.size();
  102. std::string plate_text;
  103. for (i = 0; i < len_s; i++)
  104. {
  105. plate_text += alphabet[no_repeat_blank_label[i]];
  106. }
  107. return plate_text;
  108. }
  109. cv::Mat InferNetOnnxPaddleOcrCrnnReg::preprocess(cv::Mat srcimg)
  110. {
  111. cv::Mat dstimg;
  112. int h = srcimg.rows;
  113. int w = srcimg.cols;
  114. const float ratio = w / float(h);
  115. int resized_w = int(ceil((float)this->inpHeight * ratio));
  116. if (ceil(this->inpHeight * ratio) > this->inpWidth)
  117. {
  118. resized_w = this->inpWidth;
  119. }
  120. resize(srcimg, dstimg, Size(resized_w, this->inpHeight), INTER_LINEAR);
  121. return dstimg;
  122. }
  123. void InferNetOnnxPaddleOcrCrnnReg::normalize_(cv::Mat img)
  124. {
  125. //img.convertTo(img, CV_32F);
  126. int row = img.rows;
  127. int col = img.cols;
  128. this->input_image_.resize(this->inpHeight * this->inpWidth * img.channels());
  129. for (int c = 0; c < 3; c++)
  130. {
  131. for (int i = 0; i < row; i++)
  132. {
  133. for (int j = 0; j < inpWidth; j++)
  134. {
  135. if (j < col)
  136. {
  137. float pix = img.ptr<uchar>(i)[j * 3 + c];
  138. this->input_image_[c * row * inpWidth + i * inpWidth + j] = (pix / 255.0 - 0.5) / 0.5;
  139. }
  140. else
  141. {
  142. this->input_image_[c * row * inpWidth + i * inpWidth + j] = 0;
  143. }
  144. }
  145. }
  146. }
  147. }
  148. void InferNetOnnxPaddleOcrCrnnReg::Dispose()
  149. {
  150. // 释放 ort_session 对象
  151. if (ort_session != nullptr)
  152. {
  153. delete ort_session;
  154. ort_session = nullptr;
  155. }
  156. }