# Ultralytics YOLO 🚀, AGPL-3.0 license

import os
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image

from ultralytics.utils import TQDM, checks


class FastSAMPrompt:
    """
    Fast Segment Anything Model class for image annotation and visualization.

    Attributes:
        device (str): Computing device ('cuda' or 'cpu').
        results: Object detection or segmentation results.
        source: Source image or image path.
        clip: CLIP model for linear assignment.
    """

    def __init__(self, source, results, device="cuda") -> None:
        """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
        if isinstance(source, (str, Path)) and os.path.isdir(source):
            raise ValueError("FastSAM only accepts image paths and PIL Image sources, not directories.")
        self.device = device
        self.results = results
        self.source = source

        # Import and assign clip
        try:
            import clip
        except ImportError:
            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip
        self.clip = clip

    @staticmethod
    def _segment_image(image, bbox):
        """Segments the given image according to the provided bounding box coordinates."""
        image_array = np.array(image)
        segmented_image_array = np.zeros_like(image_array)
        x1, y1, x2, y2 = bbox
        segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
        segmented_image = Image.fromarray(segmented_image_array)
        black_image = Image.new("RGB", image.size, (255, 255, 255))
        # transparency_mask = np.zeros_like((), dtype=np.uint8)
        transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
        transparency_mask[y1:y2, x1:x2] = 255
        transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
        black_image.paste(segmented_image, mask=transparency_mask_image)
        return black_image

    @staticmethod
    def _format_results(result, filter=0):
        """Formats detection results into a list of annotations, each containing ID, segmentation, bounding box, score
        and area.
        """
        annotations = []
        n = len(result.masks.data) if result.masks is not None else 0
        for i in range(n):
            mask = result.masks.data[i] == 1.0
            if torch.sum(mask) >= filter:
                annotation = {
                    "id": i,
                    "segmentation": mask.cpu().numpy(),
                    "bbox": result.boxes.data[i],
                    "score": result.boxes.conf[i],
                }
                annotation["area"] = annotation["segmentation"].sum()
                annotations.append(annotation)
        return annotations

    @staticmethod
    def _get_bbox_from_mask(mask):
        """Computes the bounding box [x1, y1, x2, y2] that encloses all external contours found in the binary mask."""
        mask = mask.astype(np.uint8)
        contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        x1, y1, w, h = cv2.boundingRect(contours[0])
        x2, y2 = x1 + w, y1 + h
        if len(contours) > 1:
            for b in contours:
                x_t, y_t, w_t, h_t = cv2.boundingRect(b)
                x1 = min(x1, x_t)
                y1 = min(y1, y_t)
                x2 = max(x2, x_t + w_t)
                y2 = max(y2, y_t + h_t)
        return [x1, y1, x2, y2]

    def plot(
        self,
        annotations,
        output,
        bbox=None,
        points=None,
        point_label=None,
        mask_random_color=True,
        better_quality=True,
        retina=False,
        with_contours=True,
    ):
        """
        Plots annotations, bounding boxes, and points on images and saves the output.

        Args:
            annotations (list): Annotations to be plotted.
            output (str or Path): Output directory for saving the plots.
            bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
            points (list, optional): Points to be plotted. Defaults to None.
            point_label (list, optional): Labels for the points. Defaults to None.
            mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
            better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
                Defaults to True.
            retina (bool, optional): Whether to use retina mask. Defaults to False.
            with_contours (bool, optional): Whether to plot contours. Defaults to True.
        """
        import matplotlib.pyplot as plt

        pbar = TQDM(annotations, total=len(annotations))
        for ann in pbar:
            result_name = os.path.basename(ann.path)
            image = ann.orig_img[..., ::-1]  # BGR to RGB
            original_h, original_w = ann.orig_shape
            # For macOS only
            # plt.switch_backend('TkAgg')
            plt.figure(figsize=(original_w / 100, original_h / 100))
            # Add subplot with no margin.
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(image)

            if ann.masks is not None:
                masks = ann.masks.data
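                # Optional mask clean-up: a morphological close fills small holes, then an open removes
                # small speckles, giving smoother mask boundaries before drawing.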
                if better_quality:
                    if isinstance(masks[0], torch.Tensor):
                        masks = np.array(masks.cpu())
                    for i, mask in enumerate(masks):
                        mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
                        masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))

                self.fast_show_mask(
                    masks,
                    plt.gca(),
                    random_color=mask_random_color,
                    bbox=bbox,
                    points=points,
                    pointlabel=point_label,
                    retinamask=retina,
                    target_height=original_h,
                    target_width=original_w,
                )

                if with_contours:
                    contour_all = []
                    temp = np.zeros((original_h, original_w, 1))
                    for i, mask in enumerate(masks):
                        mask = mask.astype(np.uint8)
                        if not retina:
                            mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
                        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                        contour_all.extend(iter(contours))
                    cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                    color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
                    contour_mask = temp / 255 * color.reshape(1, 1, -1)
                    plt.imshow(contour_mask)

            # Save the figure
            save_path = Path(output) / result_name
            save_path.parent.mkdir(exist_ok=True, parents=True)
            plt.axis("off")
            plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
            plt.close()
            pbar.set_description(f"Saving {result_name} to {save_path}")

    @staticmethod
    def fast_show_mask(
        annotation,
        ax,
        random_color=False,
        bbox=None,
        points=None,
        pointlabel=None,
        retinamask=True,
        target_height=960,
        target_width=960,
    ):
        """
        Quickly shows the mask annotations on the given matplotlib axis.

        Args:
            annotation (array-like): Mask annotation.
            ax (matplotlib.axes.Axes): Matplotlib axis.
            random_color (bool, optional): Whether to use random color for masks. Defaults to False.
            bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
            points (list, optional): Points to be plotted. Defaults to None.
            pointlabel (list, optional): Labels for the points. Defaults to None.
            retinamask (bool, optional): Whether to use retina mask. Defaults to True.
            target_height (int, optional): Target height for resizing. Defaults to 960.
            target_width (int, optional): Target width for resizing. Defaults to 960.
        """
        import matplotlib.pyplot as plt

        n, h, w = annotation.shape  # batch, height, width

        areas = np.sum(annotation, axis=(1, 2))
        annotation = annotation[np.argsort(areas)]
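        # Masks are sorted by ascending area so that the per-pixel gather below picks the first (smallest)
        # mask covering each pixel, keeping small objects visible where masks overlap.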
        index = (annotation != 0).argmax(axis=0)
        if random_color:
            color = np.random.random((n, 1, 1, 3))
        else:
            color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
        transparency = np.ones((n, 1, 1, 1)) * 0.6
        visual = np.concatenate([color, transparency], axis=-1)
        mask_image = np.expand_dims(annotation, -1) * visual

        show = np.zeros((h, w, 4))
        h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
        indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))

        show[h_indices, w_indices, :] = mask_image[indices]
        if bbox is not None:
            x1, y1, x2, y2 = bbox
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))
        # Draw point
        if points is not None:
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                s=20,
                c="y",
            )
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                s=20,
                c="m",
            )

        if not retinamask:
            show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
        ax.imshow(show)

    @torch.no_grad()
    def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
        """Processes images and text with a model, calculates similarity, and returns softmax score."""
        preprocessed_images = [preprocess(image).to(device) for image in elements]
        tokenized_text = self.clip.tokenize([search_text]).to(device)
        stacked_images = torch.stack(preprocessed_images)
        image_features = model.encode_image(stacked_images)
        text_features = model.encode_text(tokenized_text)
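        # L2-normalize both embeddings, score each crop by its scaled cosine similarity to the text
        # prompt, then take a softmax over the crops to get relative scores that sum to 1.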
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        probs = 100.0 * image_features @ text_features.T
        return probs[:, 0].softmax(dim=0)

    def _crop_image(self, format_results):
        """Crops an image based on provided annotation format and returns cropped images and related data."""
        image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
        ori_w, ori_h = image.size
        annotations = format_results
        mask_h, mask_w = annotations[0]["segmentation"].shape
        if ori_w != mask_w or ori_h != mask_h:
            image = image.resize((mask_w, mask_h))
        cropped_boxes = []
        cropped_images = []
        not_crop = []
        filter_id = []
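        # Skip masks with very small area (<= 100 px) and remember their indices so callers can map
        # CLIP scores back to the original, unfiltered annotation list.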
        for _, mask in enumerate(annotations):
            if np.sum(mask["segmentation"]) <= 100:
                filter_id.append(_)
                continue
            bbox = self._get_bbox_from_mask(mask["segmentation"])  # bbox from mask
            cropped_boxes.append(self._segment_image(image, bbox))  # save cropped image
            cropped_images.append(bbox)  # save cropped image bbox
        return cropped_boxes, cropped_images, not_crop, filter_id, annotations

    def box_prompt(self, bbox):
        """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
        if self.results[0].masks is not None:
            assert bbox[2] != 0 and bbox[3] != 0, "Bounding box width and height should not be zero"
            masks = self.results[0].masks.data
            target_height, target_width = self.results[0].orig_shape
            h = masks.shape[1]
            w = masks.shape[2]
            if h != target_height or w != target_width:
                bbox = [
                    int(bbox[0] * w / target_width),
                    int(bbox[1] * h / target_height),
                    int(bbox[2] * w / target_width),
                    int(bbox[3] * h / target_height),
                ]
            bbox[0] = max(round(bbox[0]), 0)
            bbox[1] = max(round(bbox[1]), 0)
            bbox[2] = min(round(bbox[2]), w)
            bbox[3] = min(round(bbox[3]), h)

            # IoUs = torch.zeros(len(masks), dtype=torch.float32)
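            # IoU of each mask against the box: the mask area inside the box acts as the intersection,
            # and union = box area + full mask area - intersection; only the best-overlapping mask is kept.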
            bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
            masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
            orig_masks_area = torch.sum(masks, dim=(1, 2))
            union = bbox_area + orig_masks_area - masks_area
            iou = masks_area / union
            max_iou_index = torch.argmax(iou)

            self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
        return self.results

    def point_prompt(self, points, pointlabel):  # numpy
        """Adjusts points on detected masks based on user input and returns the modified results."""
        if self.results[0].masks is not None:
            masks = self._format_results(self.results[0], 0)
            target_height, target_width = self.results[0].orig_shape
            h = masks[0]["segmentation"].shape[0]
            w = masks[0]["segmentation"].shape[1]
            if h != target_height or w != target_width:
                points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
            onemask = np.zeros((h, w))
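            # Masks hit by a foreground point (label 1) are added, masks hit by a background point
            # (label 0) are subtracted, and the running sum is thresholded into a single binary mask.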
            for annotation in masks:
                mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
                for i, point in enumerate(points):
                    if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
                        onemask += mask
                    if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
                        onemask -= mask
            onemask = onemask >= 1
            self.results[0].masks.data = torch.tensor(np.array([onemask]))
        return self.results

    def text_prompt(self, text):
        """Processes a text prompt, applies it to existing results and returns the updated results."""
        if self.results[0].masks is not None:
            format_results = self._format_results(self.results[0], 0)
            cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
            clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
            max_idx = scores.argsort()
            max_idx = max_idx[-1]
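            # Shift the best-scoring crop index by the number of annotations filtered out before it,
            # so it indexes the original (unfiltered) annotation list.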
            max_idx += sum(np.array(filter_id) <= int(max_idx))
            self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
        return self.results

    def everything_prompt(self):
        """Returns the processed results from the previous methods in the class."""
        return self.results
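

# Minimal usage sketch (illustrative only): FastSAMPrompt is normally fed the results of a FastSAM run.
# The weights file, image path and prompt values below are placeholder assumptions, not fixed names.
if __name__ == "__main__":
    from ultralytics import FastSAM

    model = FastSAM("FastSAM-s.pt")  # assumed local FastSAM weights
    everything_results = model("bus.jpg", device="cpu", retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

    prompt_process = FastSAMPrompt("bus.jpg", everything_results, device="cpu")
    ann = prompt_process.everything_prompt()  # all masks
    ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])  # mask best matching a box
    ann = prompt_process.text_prompt(text="a photo of a dog")  # mask best matching a text prompt
    ann = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])  # mask containing a point
    prompt_process.plot(annotations=ann, output="./output/")  # save visualizations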