#include"InferNetOnnxPaddleOcrDetect.h"

// Default constructor. Model-dependent state (the ONNX Runtime session, detection
// thresholds and the input/output node names) is assigned later by LoadNetwork(),
// so there is nothing extra to do here beyond the members' own initialization.
InferNetOnnxPaddleOcrDetect::InferNetOnnxPaddleOcrDetect() = default;

// Creates (or re-creates) the ONNX Runtime session from an in-memory model and
// caches the model's input/output node names for Process().
//
// @param modelData     pointer to the serialized ONNX model bytes
// @param modelDataLen  size of modelData in bytes
//
// Calling this again replaces the previous session. The previously cached node
// names are released and cleared first, so the name lists never accumulate stale
// or duplicated entries across reloads.
void InferNetOnnxPaddleOcrDetect::LoadNetwork(const void* modelData, size_t modelDataLen)
{
    AllocatorWithDefaultOptions allocator;

    if (_modelLoaded)
    {
        // A model is already loaded: release the previous session.
        delete net;
        net = nullptr;
        // Mark as not loaded until the new session is fully set up, so a throwing
        // Session constructor below does not leave the flag claiming a live model.
        _modelLoaded = false;
    }

    // The name strings were allocated by the default ORT allocator in a previous
    // load (GetInputName/GetOutputName); free them and start from empty lists.
    for (auto* name : inputNames)
        allocator.Free(const_cast<char*>(name));
    inputNames.clear();
    for (auto* name : outputNames)
        allocator.Free(const_cast<char*>(name));
    outputNames.clear();

    // Detection post-processing parameters.
    this->binaryThreshold = 0.3;
    this->polygonThreshold = 0.5;
    this->unclipRatio = 1.6;
    this->maxCandidates = 1000;

    sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);

    net = new Session(env, modelData, modelDataLen, sessionOptions);

    const size_t numInputNodes = net->GetInputCount();
    const size_t numOutputNodes = net->GetOutputCount();
    inputNames.reserve(numInputNodes);
    outputNames.reserve(numOutputNodes);
    for (size_t i = 0; i < numInputNodes; i++)
    {
        inputNames.push_back(net->GetInputName(i, allocator));
    }
    for (size_t i = 0; i < numOutputNodes; i++)
    {
        outputNames.push_back(net->GetOutputName(i, allocator));
    }
    _modelLoaded = true;
}

// Runs text detection on srcimg and returns the detected text boxes in srcimg
// pixel coordinates.
//
// @param srcimg  image to analyse; fed to the model at its own resolution
// @return        detected boxes (empty for an empty image)
//
// Requires a successful LoadNetwork() beforehand; throws cv::Exception otherwise.
// NOTE(review): preprocess() (resize so the long side is a multiple of 32) is not
// applied here, so the caller presumably supplies dimensions the model accepts —
// confirm against callers.
std::vector<TextBlock> InferNetOnnxPaddleOcrDetect::Process(cv::Mat& srcimg)
{
    if (srcimg.empty())
        return {};

    // Run() below dereferences net and the first input name.
    CV_Assert(_modelLoaded && net != nullptr && !inputNames.empty() && !outputNames.empty());

    // Fill input_image_ with the normalized planar (CHW) pixel data.
    // normalize_ takes its argument by value and only reads it, so no clone is needed.
    this->normalize_(srcimg);

    // NCHW input shape: batch 1, 3 channels, image height and width.
    const std::array<int64_t, 4> inputShape{ 1, 3, srcimg.rows, srcimg.cols };
    auto memoryInfo = MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

    // The tensor wraps input_image_ without copying it.
    Value inputTensor = Value::CreateTensor<float>(memoryInfo, input_image_.data(), input_image_.size(),
                                                   inputShape.data(), inputShape.size());
    std::vector<Value> outputs = net->Run(RunOptions{ nullptr }, inputNames.data(), &inputTensor, 1,
                                          outputNames.data(), outputNames.size());
    CV_Assert(!outputs.empty());

    // The first output is treated as the probability map. Its last two dims give
    // its height and width (presumably [1, 1, H, W] for this model — confirm).
    // Sizing the Mat from the output shape, rather than the input image, avoids
    // overflowing the buffer when the map resolution differs from the input.
    auto shapeInfo = outputs[0].GetTensorTypeAndShapeInfo();
    const std::vector<int64_t> outShape = shapeInfo.GetShape();
    CV_Assert(outShape.size() >= 2);
    const int mapRows = static_cast<int>(outShape[outShape.size() - 2]);
    const int mapCols = static_cast<int>(outShape[outShape.size() - 1]);
    CV_Assert(mapRows > 0 && mapCols > 0);
    CV_Assert(shapeInfo.GetElementCount() >= static_cast<size_t>(mapRows) * static_cast<size_t>(mapCols));

    // Deep-copy the first H x W plane so the map outlives the ORT output values.
    float* probData = outputs[0].GetTensorMutableData<float>();
    cv::Mat binary = cv::Mat(mapRows, mapCols, CV_32FC1, probData).clone();

    // Extract boxes; GetTextBoxes rescales from map to srcimg coordinates.
    return GetTextBoxes(binary, srcimg);
}

// 该函数用于对输入图像进行预处理,包括颜色空间转换和图像缩放
// Resizes srcimg so its short side becomes shortSize. The long side is scaled by
// the same factor, then rounded down to a multiple of 32 (minimum 32).
//
// @param srcimg  input image (left unmodified)
// @return        a newly allocated resized image (an empty Mat for empty input)
//
// The target size is computed directly in integers. Re-deriving it from a float
// scale and truncating (e.g. int(736.0f / h * h) == 735) could lose a pixel and
// break the multiple-of-32 constraint.
cv::Mat InferNetOnnxPaddleOcrDetect::preprocess(cv::Mat srcimg)
{
    if (srcimg.empty())
        return srcimg.clone();  // avoids dividing by a zero dimension below

    const int h = srcimg.rows;
    const int w = srcimg.cols;
    const float shortSide = static_cast<float>(this->shortSize);

    int targetH = 0;
    int targetW = 0;
    if (h < w)
    {
        // Height is the short side: it becomes shortSize; width follows.
        const float scale = shortSide / static_cast<float>(h);
        targetH = cvRound(static_cast<float>(h) * scale);
        targetW = static_cast<int>(static_cast<float>(w) * scale);
        targetW = std::max(32, targetW - targetW % 32);
    }
    else
    {
        // Width is the short side (or the image is square): it becomes shortSize.
        const float scale = shortSide / static_cast<float>(w);
        targetW = cvRound(static_cast<float>(w) * scale);
        targetH = static_cast<int>(static_cast<float>(h) * scale);
        targetH = std::max(32, targetH - targetH % 32);
    }

    // Pass the interpolation flag in its own parameter slot (fx/fy are unused
    // because an explicit destination size is given).
    cv::Mat dstimg;
    resize(srcimg, dstimg, Size(targetW, targetH), 0, 0, INTER_LINEAR);
    return dstimg;
}

// Converts img into the planar (CHW) float buffer input_image_.
// Each of the 3 channels is mapped as (pixel / 255 - meanValues[c]) / normValues[c].
//
// @param img  8-bit image with 1, 3 or 4 channels (received by value; the caller's
//             pixels are never modified)
//
// The buffer layout is fixed at 3 planes. Single-channel input is expanded to 3
// identical channels and 4-channel input has its 4th channel dropped, so the code
// never reads or writes out of bounds. Any other type throws cv::Exception.
void InferNetOnnxPaddleOcrDetect::normalize_(cv::Mat img)
{
    cv::Mat threeChannel;
    if (img.channels() == 1)
        cvtColor(img, threeChannel, COLOR_GRAY2BGR);
    else if (img.channels() == 4)
        cvtColor(img, threeChannel, COLOR_BGRA2BGR);
    else
        threeChannel = img;
    CV_Assert(threeChannel.type() == CV_8UC3);

    const int row = threeChannel.rows;
    const int col = threeChannel.cols;
    const int planeSize = row * col;
    this->input_image_.resize(static_cast<size_t>(planeSize) * 3);

    for (int i = 0; i < row; i++)
    {
        const uchar* rowPtr = threeChannel.ptr<uchar>(i);
        for (int j = 0; j < col; j++)
        {
            for (int c = 0; c < 3; c++)
            {
                // Interleaved source pixel -> planar destination at plane c.
                float pix = rowPtr[j * 3 + c];
                this->input_image_[c * planeSize + i * col + j] = (pix / 255.0 - this->meanValues[c]) / this->normValues[c];
            }
        }
    }
}

// Returns the arithmetic mean of the polygon's vertices. This is not the
// area-weighted centroid.
//
// @param polygonBoxs  polygon vertices (any count)
// @return             mean vertex position; (0, 0) for an empty polygon
Point2f InferNetOnnxPaddleOcrDetect::pointCenBoxs(vector<Point2f> polygonBoxs)
{
    // An empty polygon has no centre; return the origin instead of 0/0 = NaN.
    if (polygonBoxs.empty())
        return Point2f(0, 0);

    Point2f center(0, 0);
    for (const Point2f& vertex : polygonBoxs)
        center += vertex;

    const float count = static_cast<float>(polygonBoxs.size());
    center.x /= count;
    center.y /= count;
    return center;
}
// Converts a probability map into text boxes expressed in srcimgIN coordinates.
//
// @param binaryIN  single-channel float probability map produced by the model
// @param srcimgIN  image the boxes should be reported in (only its size is used)
// @return          accepted boxes, in reverse contour-discovery order
//
// Pipeline per contour:
//   score -> scale to image -> minAreaRect -> size filter -> unclip (expand)
//   -> size filter -> in-image check.
// Each returned TextBlock carries the contour's mean probability as boxScore.
std::vector<TextBlock> InferNetOnnxPaddleOcrDetect::GetTextBoxes(cv::Mat& binaryIN, cv::Mat& srcimgIN)
{
    // Size of the image the boxes are reported in.
    const int h = srcimgIN.rows;
    const int w = srcimgIN.cols;

    // Binarize: pixels above binaryThreshold become 255, the rest 0.
    Mat bitmap;
    threshold(binaryIN, bitmap, binaryThreshold, 255, THRESH_BINARY);
    bitmap.convertTo(bitmap, CV_8UC1);  // findContours needs an 8-bit single-channel image

    // Factors mapping probability-map coordinates to image coordinates.
    const float scaleHeight = static_cast<float>(h) / static_cast<float>(binaryIN.size[0]);
    const float scaleWidth = static_cast<float>(w) / static_cast<float>(binaryIN.size[1]);

    vector< vector<Point> > contours;
    findContours(bitmap, contours, RETR_LIST, CHAIN_APPROX_SIMPLE);

    // Only the first maxCandidates contours are examined (all of them if maxCandidates <= 0).
    const size_t numCandidate = std::min(contours.size(),
        static_cast<size_t>(maxCandidates > 0 ? maxCandidates : INT_MAX));

    std::vector<TextBlock> rsBoxes;
    for (size_t i = 0; i < numCandidate; i++)
    {
        vector<Point>& contour = contours[i];

        // Mean probability inside the contour; weak candidates are dropped.
        const float score = contourScore(binaryIN, contour);
        if (score < polygonThreshold)
            continue;

        // Map the contour into image coordinates.
        vector<Point> contourScaled;
        contourScaled.reserve(contour.size());
        for (const Point& p : contour)
            contourScaled.emplace_back(int(p.x * scaleWidth), int(p.y * scaleHeight));

        const bool coordinatesValid = std::all_of(contourScaled.begin(), contourScaled.end(),
            [w, h](const Point& p) { return p.x >= 0 && p.y >= 0 && p.x < w && p.y < h; });
        if (!coordinatesValid)
            continue;

        RotatedRect box = minAreaRect(contourScaled);
        float longSide = std::max(box.size.width, box.size.height);
        if (longSide < longSideThresh)
            continue;

        // minArea() rect is not normalized, it may return rectangles with angle=-90 or height < width
        const float angle_threshold = 60;  // do not expect vertical text, TODO detection algo property
        bool swap_size = false;
        if (box.size.width < box.size.height)  // horizontal-wide text area is expected
            swap_size = true;
        else if (fabs(box.angle) >= angle_threshold)  // don't work with vertical rectangles
            swap_size = true;
        if (swap_size)
        {
            std::swap(box.size.width, box.size.height);
            if (box.angle < 0)
                box.angle += 90;
            else if (box.angle > 0)
                box.angle -= 90;
        }

        // Expand ("unclip") the rectangle so the polygon covers the full text extent.
        Point2f vertex[4];
        box.points(vertex);
        vector<Point2f> approx(vertex, vertex + 4);
        vector<Point2f> polygon;
        unclip(approx, polygon);
        if (polygon.empty())
            continue;  // minAreaRect would reject an empty point set

        // NOTE(review): the angle reported below comes from this post-unclip rect,
        // not the normalized one above — confirm this is what consumers expect.
        box = minAreaRect(polygon);
        longSide = std::max(box.size.width, box.size.height);
        if (longSide < longSideThresh + 2)
            continue;

        const bool insideImage = std::all_of(polygon.begin(), polygon.end(),
            [w, h](const Point2f& p) { return p.x >= 0 && p.x <= w && p.y >= 0 && p.y <= h; });
        if (!insideImage)
            continue;

        TextBlock detectedBox;
        detectedBox.boxVertices = polygon;
        detectedBox.boxCenterVer = pointCenBoxs(polygon);
        detectedBox.angle = box.angle;
        detectedBox.boxScore = score;  // previously always 0: only the rejected path assigned it
        rsBoxes.push_back(detectedBox);
    }

    // Report boxes in reverse order of contour discovery.
    std::reverse(rsBoxes.begin(), rsBoxes.end());
    return rsBoxes;
}

std::vector< std::vector<Point2f> > InferNetOnnxPaddleOcrDetect::order_points_clockwise(std::vector< std::vector<Point2f> > results)
{
    std::vector< std::vector<Point2f> > order_points(results);
    for (int i = 0; i < results.size(); i++)
    {
        float max_sum_pts = -10000;
        float min_sum_pts = 10000;
        float max_diff_pts = -10000;
        float min_diff_pts = 10000;

        int max_sum_pts_id = 0;
        int min_sum_pts_id = 0;
        int max_diff_pts_id = 0;
        int min_diff_pts_id = 0;
        for (int j = 0; j < 4; j++)
        {
            const float sum_pt = results[i][j].x + results[i][j].y;
            if (sum_pt > max_sum_pts)
            {
                max_sum_pts = sum_pt;
                max_sum_pts_id = j;
            }
            if (sum_pt < min_sum_pts)
            {
                min_sum_pts = sum_pt;
                min_sum_pts_id = j;
            }

            const float diff_pt = results[i][j].y - results[i][j].x;
            if (diff_pt > max_diff_pts)
            {
                max_diff_pts = diff_pt;
                max_diff_pts_id = j;
            }
            if (diff_pt < min_diff_pts)
            {
                min_diff_pts = diff_pt;
                min_diff_pts_id = j;
            }
        }
        order_points[i][0].x = results[i][min_sum_pts_id].x;
        order_points[i][0].y = results[i][min_sum_pts_id].y;
        order_points[i][2].x = results[i][max_sum_pts_id].x;
        order_points[i][2].y = results[i][max_sum_pts_id].y;

        order_points[i][1].x = results[i][min_diff_pts_id].x;
        order_points[i][1].y = results[i][min_diff_pts_id].y;
        order_points[i][3].x = results[i][max_diff_pts_id].x;
        order_points[i][3].y = results[i][max_diff_pts_id].y;
    }
    return order_points;
}

// Draws each polygon onto srcimg, in place.
//
// @param srcimg   image to draw on
// @param results  polygons to draw
//
// Each vertex is drawn as a filled red (BGR 0,0,255) dot of radius 2. Each edge,
// including the closing edge from the last vertex to the first, is a green line.
// Vertices are truncated to integer pixels. All vertices are used, so polygons
// that are not quadrilaterals no longer read out of bounds.
void InferNetOnnxPaddleOcrDetect::drawPred(cv::Mat& srcimg, std::vector< std::vector<Point2f> > results)
{
    const Scalar vertexColor(0, 0, 255);
    const Scalar edgeColor(0, 255, 0);

    for (const auto& poly : results)
    {
        const size_t n = poly.size();
        for (size_t j = 0; j < n; j++)
        {
            const Point2f& nextVertex = poly[(j + 1) % n];
            const Point cur(static_cast<int>(poly[j].x), static_cast<int>(poly[j].y));
            const Point next(static_cast<int>(nextVertex.x), static_cast<int>(nextVertex.y));

            // Same draw order as before (dot, then outgoing edge) so overlapping
            // pixels look identical for quadrilaterals.
            circle(srcimg, cur, 2, vertexColor, -1);
            line(srcimg, cur, next, edgeColor);
        }
    }
}

// 该函数计算二进制图像中指定轮廓的分数
// Scores a contour as the mean value of `binary` over the pixels inside it.
//
// @param binary   single-channel map the contour was extracted from
// @param contour  contour in `binary` pixel coordinates
// @return         mean of `binary` under the filled contour mask
//
// Only the contour's bounding box (clamped to the image) is examined. A local
// mask is rasterized there, which avoids allocating a full-size mask.
float InferNetOnnxPaddleOcrDetect::contourScore(cv::Mat& binary, std::vector<Point>& contour)
{
    const Rect bbox = boundingRect(contour);

    // Clamp the box to valid pixel indices of `binary`.
    const int left = std::max(bbox.x, 0);
    const int top = std::max(bbox.y, 0);
    const int right = std::min(bbox.x + bbox.width, binary.cols - 1);
    const int bottom = std::min(bbox.y + bbox.height, binary.rows - 1);
    const int roiWidth = right - left + 1;
    const int roiHeight = bottom - top + 1;

    // Shift the contour into ROI-local coordinates.
    const Point offset(left, top);
    std::vector<std::vector<Point>> shifted(1);
    shifted[0].reserve(contour.size());
    for (const Point& p : contour)
        shifted[0].push_back(p - offset);

    // Mask = 1 inside the contour, 0 elsewhere.
    cv::Mat mask(roiHeight, roiWidth, CV_8U, Scalar(0));
    fillPoly(mask, shifted, Scalar(1));

    const cv::Mat region = binary(Rect(left, top, roiWidth, roiHeight));
    return static_cast<float>(mean(region, mask)[0]);
}

void InferNetOnnxPaddleOcrDetect::unclip(std::vector<Point2f>& inPoly, std::vector<Point2f>& outPoly)
{
    // 计算轮廓的面积
    float area = contourArea(inPoly);
    float length = arcLength(inPoly, true); // 计算轮廓的周长
    float distance = area * unclipRatio / length; // 计算解剪距离
    // 获取输入轮廓的点数
    size_t numPoints = inPoly.size();
    // 存储新的轮廓线段
    std::vector<std::vector<Point2f>> newLines;
    // 遍历原始轮廓的每个点
    for (size_t i = 0; i < numPoints; i++)
    {
        std::vector<Point2f> newLine;
        Point pt1 = inPoly[i];
        Point pt2 = inPoly[(i - 1) % numPoints];
        Point vec = pt1 - pt2;
        // 计算解剪距离
        float unclipDis = (float)(distance / norm(vec));
        // 计算旋转后的向量
        Point2f rotateVec = Point2f(vec.y * unclipDis, -vec.x * unclipDis);
        // 添加旋转后的点到新线段
        newLine.push_back(Point2f(pt1.x + rotateVec.x, pt1.y + rotateVec.y));
        newLine.push_back(Point2f(pt2.x + rotateVec.x, pt2.y + rotateVec.y));
        newLines.push_back(newLine);
    }
    // 获取新线段的数量
    size_t numLines = newLines.size();
    // 遍历新线段集合
    for (size_t i = 0; i < numLines; i++)
    {
        Point2f a = newLines[i][0];
        Point2f b = newLines[i][1];
        Point2f c = newLines[(i + 1) % numLines][0];
        Point2f d = newLines[(i + 1) % numLines][1];
        Point2f pt;
        // 计算两向量的夹角余弦值
        Point2f v1 = b - a;
        Point2f v2 = d - c;
        float cosAngle = (v1.x * v2.x + v1.y * v2.y) / (norm(v1) * norm(v2));
        // 根据夹角余弦值判断旋转后的点位置
        if (fabs(cosAngle) > 0.7)
        {
            pt.x = (b.x + c.x) * 0.5;
            pt.y = (b.y + c.y) * 0.5;
        }
        else
        {
            float denom = a.x * (float)(d.y - c.y) + b.x * (float)(c.y - d.y) +
                d.x * (float)(b.y - a.y) + c.x * (float)(a.y - b.y);
            float num = a.x * (float)(d.y - c.y) + c.x * (float)(a.y - d.y) + d.x * (float)(c.y - a.y);
            float s = num / denom;

            pt.x = a.x + s * (b.x - a.x);
            pt.y = a.y + s * (b.y - a.y);
        }
        // 将计算得到的点添加到输出轮廓
        outPoly.push_back(pt);
    }
}

// Extracts the quadrilateral described by the first four `vertices` from `frame`
// and warps it onto an upright rectangle the size of the vertices' bounding box.
//
// @param frame     source image
// @param vertices  at least 4 points. The first four map to the output's
//                  (bottom-left, top-left, top-right, bottom-right) corners.
// @return          warped crop. It is empty when the region lies entirely outside
//                  `frame`. Throws cv::Exception when fewer than 4 vertices are given.
//
// Parts of the region outside `frame` are filled by replicating the crop's border.
cv::Mat InferNetOnnxPaddleOcrDetect::getRotateCropImage(cv::Mat& frame, std::vector<Point2f> vertices)
{
    // getPerspectiveTransform needs exactly four correspondences.
    CV_Assert(vertices.size() >= 4);
    std::vector<Point2f> quad(vertices.begin(), vertices.begin() + 4);

    // The output size follows the full bounding box. The pixels are read from its
    // intersection with the frame, because frame(rect) throws for out-of-image
    // rectangles and boxes may touch or pass the image edge.
    const Rect textRect = boundingRect(vertices);
    const Rect cropRect = textRect & Rect(0, 0, frame.cols, frame.rows);
    if (textRect.width <= 0 || textRect.height <= 0 || cropRect.width <= 0 || cropRect.height <= 0)
        return cv::Mat();

    cv::Mat crop_img = frame(cropRect);

    const Size outputSize(textRect.width, textRect.height);
    const std::vector<Point2f> targetVertices{
        Point2f(0, outputSize.height), Point2f(0, 0),
        Point2f(outputSize.width, 0), Point2f(outputSize.width, outputSize.height) };

    // Express the source corners in crop-local coordinates.
    for (Point2f& p : quad)
    {
        p.x -= cropRect.x;
        p.y -= cropRect.y;
    }

    const cv::Mat transform = cv::getPerspectiveTransform(quad, targetVertices);

    // The 5th argument of warpPerspective is the interpolation flag; the border
    // mode is the 6th. Previously BORDER_REPLICATE landed in the flags slot and
    // the border defaulted to BORDER_CONSTANT.
    cv::Mat result;
    cv::warpPerspective(crop_img, result, transform, outputSize, cv::INTER_LINEAR, cv::BORDER_REPLICATE);
    return result;
}

void InferNetOnnxPaddleOcrDetect::Dispose()
{
    // 在此处释放资源,确保在对象销毁时调用
    //delete net;
    //net = nullptr;
}