masks.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import cv2
  2. import numpy as np
  3. def unmold_mask(mask, bbox, image_shape):
  4. """
  5. 将mask缩放到图像坐标系
  6. :param mask: [height, width] float型,通常较小
  7. :param bbox: [y1, x1, y2, x2] mask所在的box
  8. :param image_shape: 所需匹配的图像坐标
  9. :return: binary mask(图像尺寸为image_shape指定)
  10. """
  11. threshold = 0.5
  12. y1, x1, y2, x2 = bbox
  13. mask = cv2.resize(mask, (x2-x1, y2-y1))
  14. mask = np.where(mask >= threshold, 1, 0).astype(np.bool)
  15. # 将mask 放在正确的位置
  16. full_mask = np.zeros((int(image_shape[0]), int(image_shape[1])), dtype=np.bool)
  17. full_mask[y1:y2, x1:x2] = mask
  18. return full_mask
  19. def get_bounding_boxes(masks):
  20. num_masks = masks.shape[-1]
  21. bounding_boxes = np.zeros([num_masks, 4], dtype=np.int32)
  22. for i in range(num_masks):
  23. m = masks[:, :, i]
  24. horizontal_indicies = np.where(np.any(m, axis=0))[0]
  25. vertical_indicies = np.where(np.any(m, axis=1))[0]
  26. if horizontal_indicies.shape[0]:
  27. x1, x2 = horizontal_indicies[[0, -1]]
  28. y1, y2 = vertical_indicies[[0, -1]]
  29. x2 += 1
  30. y2 += 1
  31. else:
  32. x1, x2, y1, y2 = 0, 0, 0, 0
  33. bounding_boxes[i] = np.array([y1, x1, y2, x2])
  34. return bounding_boxes
  35. def compute_overlaps_masks(masks1, masks2):
  36. """
  37. 计算两组masks的重叠率
  38. :param masks1: [height, width, num_instance_1]
  39. :param masks2: [height, width, num_instance_2]
  40. :return:
  41. """
  42. # 任意一个输入为空,则返回结果为空
  43. if masks1.shape[-1] == 0 or masks2.shape[-1] == 0:
  44. return np.zeros((masks1.shape[-1], masks2.shape[-2]))
  45. # 展平,计算面积
  46. masks1 = np.reshape(masks1 > .5, (-1, masks1.shape[-1])).astype(np.float32)
  47. masks2 = np.reshape(masks2 > .5, (-1, masks2.shape[-1])).astype(np.float32)
  48. area1 = np.sum(masks1, axis=0)
  49. area2 = np.sum(masks2, axis=0)
  50. # 计算IoU
  51. intersections = np.dot(masks1.T, masks2)
  52. union = area1[:, None] + area2[None, :] - intersections
  53. overlaps = intersections / union
  54. return overlaps
  55. def compute_ratio_masks_in_rect(masks, rect):
  56. """
  57. 计算每一个mask在rect中的面积比
  58. :param masks:
  59. :param rect: 未归一化
  60. :return:
  61. """
  62. num_mask = masks.shape[-1]
  63. ratios = np.zeros(num_mask, dtype=np.float32)
  64. for i in range(num_mask):
  65. area_mask = np.sum(masks[:, :, i])
  66. if area_mask <= 0:
  67. ratios[i] = 0
  68. area_in_rect = np.sum(masks[rect[1]: rect[3], rect[0]: rect[2], i])
  69. ratios[i] = area_in_rect/float(area_mask)
  70. return ratios
  71. def masks_center_in_rect(masks, rect):
  72. """
  73. 判断每一个mask的中心是否在rect以内
  74. :param masks:
  75. :param rect: 未归一化
  76. :return:
  77. """
  78. num_mask = masks.shape[-1]
  79. in_rect = np.zeros(num_mask, dtype=np.bool)
  80. for i in range(num_mask):
  81. inds = np.where(masks[:, :, i])
  82. if len(inds[0]) > 0:
  83. center_y = np.mean(inds[0])
  84. center_x = np.mean(inds[1])
  85. if (center_y >= rect[1]) and (center_y <= rect[3])\
  86. and (center_x >= rect[0]) and (center_x <= rect[2]):
  87. in_rect[i] = True
  88. return in_rect
  89. def clip_masks(masks, class_ids, clip_rect, in_rect=None):
  90. """
  91. 如果mask的中心不在clip_rect范围内,则直接去掉该标注,如果在clip_rect范围内,则只裁取clip_rect范围内的部分
  92. :param masks:
  93. :param class_ids:
  94. :param clip_rect:
  95. :param in_rect:
  96. :return:
  97. """
  98. if in_rect is None:
  99. in_rect = masks_center_in_rect(masks, clip_rect)
  100. num_mask = masks.shape[-1]
  101. clip_rect_height = clip_rect[3] - clip_rect[1]
  102. clip_rect_width = clip_rect[2] - clip_rect[0]
  103. clipped_masks = np.zeros((clip_rect_height, clip_rect_width, num_mask), dtype=masks.dtype)
  104. clipped_class_ids = np.zeros_like(class_ids)
  105. clipped_ind = 0
  106. for i in range(num_mask):
  107. if in_rect[i]:
  108. clipped_masks[:, :, clipped_ind] = masks[clip_rect[1]: clip_rect[3], clip_rect[0]: clip_rect[2], i]
  109. clipped_class_ids[clipped_ind] = class_ids[i]
  110. clipped_ind += 1
  111. return clipped_masks, clipped_class_ids
  112. def minimize_mask(bbox, mask, mini_shape):
  113. """Resize masks to a smaller version to reduce memory load.
  114. Mini-masks can be resized back to image scale using expand_masks()
  115. :param bbox: [num_gt, (y1, x1, y2, x2)] int32
  116. :param mask: [resize_height, resize_width, num_gt] uint8
  117. :param mini_shape: [mini_mask_height, mini_mask_width]
  118. """
  119. # H W C
  120. mini_height = mini_shape[0]
  121. mini_width = mini_shape[1]
  122. num_gt = mask.shape[-1]
  123. mini_mask = np.zeros((mini_height, mini_width, num_gt), dtype=np.uint8)
  124. for i in range(num_gt):
  125. m = mask[:, :, i].astype(np.uint8)
  126. y1, x1, y2, x2 = bbox[i][:4]
  127. m = m[y1:y2, x1:x2]
  128. if m.size == 0:
  129. m = np.zeros((mini_height, mini_width), dtype=np.uint8)
  130. else:
  131. m = cv2.resize(m, (mini_width, mini_height), 0, 0, interpolation=cv2.INTER_LINEAR)
  132. mini_mask[:, :, i] = np.around(m)
  133. return mini_mask