heatmap.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. import warnings
  2. warnings.filterwarnings('ignore')
  3. warnings.simplefilter('ignore')
  4. import torch, yaml, cv2, os, shutil, sys
  5. import numpy as np
  6. np.random.seed(0)
  7. import matplotlib.pyplot as plt
  8. from tqdm import trange
  9. from PIL import Image
  10. from ultralytics.nn.tasks import attempt_load_weights
  11. from ultralytics.utils.torch_utils import intersect_dicts
  12. from ultralytics.utils.ops import xywh2xyxy, non_max_suppression
  13. from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM, EigenCAM, HiResCAM, LayerCAM, RandomCAM, EigenGradCAM
  14. from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
  15. from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
  16. def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
  17. # Resize and pad image while meeting stride-multiple constraints
  18. shape = im.shape[:2] # current shape [height, width]
  19. if isinstance(new_shape, int):
  20. new_shape = (new_shape, new_shape)
  21. # Scale ratio (new / old)
  22. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  23. if not scaleup: # only scale down, do not scale up (for better val mAP)
  24. r = min(r, 1.0)
  25. # Compute padding
  26. ratio = r, r # width, height ratios
  27. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  28. dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
  29. if auto: # minimum rectangle
  30. dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
  31. elif scaleFill: # stretch
  32. dw, dh = 0.0, 0.0
  33. new_unpad = (new_shape[1], new_shape[0])
  34. ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
  35. dw /= 2 # divide padding into 2 sides
  36. dh /= 2
  37. if shape[::-1] != new_unpad: # resize
  38. im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
  39. top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
  40. left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  41. im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
  42. return im, ratio, (dw, dh)
  43. class ActivationsAndGradients:
  44. """ Class for extracting activations and
  45. registering gradients from targetted intermediate layers """
  46. def __init__(self, model, target_layers, reshape_transform):
  47. self.model = model
  48. self.gradients = []
  49. self.activations = []
  50. self.reshape_transform = reshape_transform
  51. self.handles = []
  52. for target_layer in target_layers:
  53. self.handles.append(
  54. target_layer.register_forward_hook(self.save_activation))
  55. # Because of https://github.com/pytorch/pytorch/issues/61519,
  56. # we don't use backward hook to record gradients.
  57. self.handles.append(
  58. target_layer.register_forward_hook(self.save_gradient))
  59. def save_activation(self, module, input, output):
  60. activation = output
  61. if self.reshape_transform is not None:
  62. activation = self.reshape_transform(activation)
  63. self.activations.append(activation.cpu().detach())
  64. def save_gradient(self, module, input, output):
  65. if not hasattr(output, "requires_grad") or not output.requires_grad:
  66. # You can only register hooks on tensor requires grad.
  67. return
  68. # Gradients are computed in reverse order
  69. def _store_grad(grad):
  70. if self.reshape_transform is not None:
  71. grad = self.reshape_transform(grad)
  72. self.gradients = [grad.cpu().detach()] + self.gradients
  73. output.register_hook(_store_grad)
  74. def post_process(self, result):
  75. if not self.model.end2end:
  76. logits_ = result[:, 4:]
  77. boxes_ = result[:, :4]
  78. sorted, indices = torch.sort(logits_.max(1)[0], descending=True)
  79. return torch.transpose(logits_[0], dim0=0, dim1=1)[indices[0]], torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]]
  80. else:
  81. logits_ = result[:, :, 4:]
  82. boxes_ = result[:, :, :4]
  83. sorted, indices = torch.sort(logits_[:, :, 0], descending=True)
  84. return logits_[0][indices[0]], boxes_[0][indices[0]]
  85. def __call__(self, x):
  86. self.gradients = []
  87. self.activations = []
  88. model_output = self.model(x)
  89. post_result, pre_post_boxes = self.post_process(model_output[0])
  90. return [[post_result, pre_post_boxes]]
  91. def release(self):
  92. for handle in self.handles:
  93. handle.remove()
  94. class yolov8_target(torch.nn.Module):
  95. def __init__(self, ouput_type, conf, ratio, end2end) -> None:
  96. super().__init__()
  97. self.ouput_type = ouput_type
  98. self.conf = conf
  99. self.ratio = ratio
  100. self.end2end = end2end
  101. def forward(self, data):
  102. post_result, pre_post_boxes = data
  103. result = []
  104. for i in trange(int(post_result.size(0) * self.ratio)):
  105. if (self.end2end and float(post_result[i, 0]) < self.conf) or (not self.end2end and float(post_result[i].max()) < self.conf):
  106. break
  107. if self.ouput_type == 'class' or self.ouput_type == 'all':
  108. if self.end2end:
  109. result.append(post_result[i, 0])
  110. else:
  111. result.append(post_result[i].max())
  112. elif self.ouput_type == 'box' or self.ouput_type == 'all':
  113. for j in range(4):
  114. result.append(pre_post_boxes[i, j])
  115. return sum(result)
  116. class yolov8_heatmap:
  117. def __init__(self, weight, device, method, layer, backward_type, conf_threshold, ratio, show_box, renormalize):
  118. device = torch.device(device)
  119. ckpt = torch.load(weight)
  120. model_names = ckpt['model'].names
  121. model = attempt_load_weights(weight, device)
  122. model.info()
  123. for p in model.parameters():
  124. p.requires_grad_(True)
  125. model.eval()
  126. target = yolov8_target(backward_type, conf_threshold, ratio, model.end2end)
  127. target_layers = [model.model[l] for l in layer]
  128. method = eval(method)(model, target_layers, use_cuda=device.type == 'cuda')
  129. method.activations_and_grads = ActivationsAndGradients(model, target_layers, None)
  130. colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int32)
  131. self.__dict__.update(locals())
  132. def post_process(self, result):
  133. result = non_max_suppression(result, conf_thres=self.conf_threshold, iou_thres=0.65)[0]
  134. return result
  135. def draw_detections(self, box, color, name, img):
  136. xmin, ymin, xmax, ymax = list(map(int, list(box)))
  137. cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)
  138. cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)
  139. return img
  140. def renormalize_cam_in_bounding_boxes(self, boxes, image_float_np, grayscale_cam):
  141. """Normalize the CAM to be in the range [0, 1]
  142. inside every bounding boxes, and zero outside of the bounding boxes. """
  143. renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
  144. for x1, y1, x2, y2 in boxes:
  145. x1, y1 = max(x1, 0), max(y1, 0)
  146. x2, y2 = min(grayscale_cam.shape[1] - 1, x2), min(grayscale_cam.shape[0] - 1, y2)
  147. renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
  148. renormalized_cam = scale_cam_image(renormalized_cam)
  149. eigencam_image_renormalized = show_cam_on_image(image_float_np, renormalized_cam, use_rgb=True)
  150. return eigencam_image_renormalized
  151. def process(self, img_path, save_path):
  152. # img process
  153. img = cv2.imread(img_path)
  154. img = letterbox(img)[0]
  155. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  156. img = np.float32(img) / 255.0
  157. tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)
  158. try:
  159. grayscale_cam = self.method(tensor, [self.target])
  160. except AttributeError as e:
  161. return
  162. grayscale_cam = grayscale_cam[0, :]
  163. cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)
  164. pred = self.model(tensor)[0]
  165. if not self.model.end2end:
  166. pred = self.post_process(pred)
  167. else:
  168. pred = pred[0][pred[0, :, 4] > self.conf_threshold]
  169. if self.renormalize:
  170. cam_image = self.renormalize_cam_in_bounding_boxes(pred[:, :4].cpu().detach().numpy().astype(np.int32), img, grayscale_cam)
  171. if self.show_box:
  172. for data in pred:
  173. data = data.cpu().detach().numpy()
  174. cam_image = self.draw_detections(data[:4], self.colors[int(data[5])], f'{self.model_names[int(data[5])]} {float(data[4]):.2f}', cam_image)
  175. cam_image = Image.fromarray(cam_image)
  176. cam_image.save(save_path)
  177. def __call__(self, img_path, save_path):
  178. # remove dir if exist
  179. if os.path.exists(save_path):
  180. shutil.rmtree(save_path)
  181. # make dir if not exist
  182. os.makedirs(save_path, exist_ok=True)
  183. if os.path.isdir(img_path):
  184. for img_path_ in os.listdir(img_path):
  185. self.process(f'{img_path}/{img_path_}', f'{save_path}/{img_path_}')
  186. else:
  187. self.process(img_path, f'{save_path}/result.png')
  188. def get_params():
  189. params = {
  190. 'weight': 'runs/train/exp/weights/best.pt', # 现在只需要指定权重即可,不需要指定cfg
  191. 'device': 'cuda:0',
  192. 'method': 'HiResCAM', # GradCAMPlusPlus, GradCAM, XGradCAM, EigenCAM, HiResCAM, LayerCAM, RandomCAM, EigenGradCAM
  193. 'layer': [10, 12, 14, 16, 18],
  194. 'backward_type': 'class', # class, box, all
  195. 'conf_threshold': 0.2, # 0.2
  196. 'ratio': 0.02, # 0.02-0.1
  197. 'show_box': True,
  198. 'renormalize': False
  199. }
  200. return params
  201. if __name__ == '__main__':
  202. model = yolov8_heatmap(**get_params())
  203. # model(r'/root/dataset/dataset_visdrone/VisDrone2019-DET-test-dev/images/9999963_00000_d_0000020.jpg', 'result')
  204. model(r'/root/dataset/dataset_visdrone/VisDrone2019-DET-test-dev/images', 'result')