get_model_erf.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import warnings
  2. warnings.filterwarnings('ignore')
  3. warnings.simplefilter('ignore')
  4. import torch, yaml, cv2, os, shutil, sys, glob
  5. import numpy as np
  6. np.random.seed(0)
  7. import matplotlib.pyplot as plt
  8. from tqdm import trange
  9. from PIL import Image
  10. from ultralytics.nn.tasks import attempt_load_weights
  11. from timm.utils import AverageMeter
  12. import matplotlib.pyplot as plt
  13. plt.rcParams["font.family"] = "Times New Roman"
  14. import seaborn as sns
  15. def get_activation(feat, backbone_idx=-1):
  16. def hook(model, inputs, outputs):
  17. if backbone_idx != -1:
  18. for _ in range(5 - len(outputs)): outputs.insert(0, None)
  19. feat.append(outputs[backbone_idx])
  20. else:
  21. feat.append(outputs)
  22. return hook
  23. def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
  24. # Resize and pad image while meeting stride-multiple constraints
  25. shape = im.shape[:2] # current shape [height, width]
  26. if isinstance(new_shape, int):
  27. new_shape = (new_shape, new_shape)
  28. # Scale ratio (new / old)
  29. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  30. if not scaleup: # only scale down, do not scale up (for better val mAP)
  31. r = min(r, 1.0)
  32. # Compute padding
  33. ratio = r, r # width, height ratios
  34. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  35. dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
  36. if auto: # minimum rectangle
  37. dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
  38. elif scaleFill: # stretch
  39. dw, dh = 0.0, 0.0
  40. new_unpad = (new_shape[1], new_shape[0])
  41. ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
  42. dw /= 2 # divide padding into 2 sides
  43. dh /= 2
  44. if shape[::-1] != new_unpad: # resize
  45. im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
  46. top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
  47. left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  48. im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
  49. return im, ratio, (dw, dh)
  50. def get_rectangle(data, thresh):
  51. h, w = data.shape
  52. all_sum = np.sum(data)
  53. for i in range(1, h // 2):
  54. selected_area = data[h // 2 - i:h // 2 + 1 + i, w // 2 - i:w // 2 + 1 + i]
  55. area_sum = np.sum(selected_area)
  56. if area_sum / all_sum > thresh:
  57. return i * 2 + 1, (i * 2 + 1) / h * (i * 2 + 1) / w
  58. return None
  59. def heatmap(data, camp='RdYlGn', figsize=(10, 10.75), ax=None, save_path=None):
  60. plt.figure(figsize=figsize, dpi=40)
  61. ax = sns.heatmap(data,
  62. xticklabels=False,
  63. yticklabels=False, cmap=camp,
  64. center=0, annot=False, ax=ax, cbar=True, annot_kws={"size": 24}, fmt='.2f')
  65. plt.tight_layout()
  66. plt.savefig(save_path)
  67. class yolov8_erf:
  68. feature, hooks = [], []
  69. def __init__(self, weight, device, layer, dataset, num_images, save_path) -> None:
  70. device = torch.device(device)
  71. ckpt = torch.load(weight)
  72. model = attempt_load_weights(weight, device)
  73. model.info()
  74. for p in model.parameters():
  75. p.requires_grad_(True)
  76. model.eval()
  77. optimizer = torch.optim.SGD(model.parameters(), lr=0, weight_decay=0)
  78. meter = AverageMeter()
  79. optimizer.zero_grad()
  80. if '-' in layer:
  81. layer_first, layer_second = layer.split('-')
  82. self.hooks.append(model.model[int(layer_first)].register_forward_hook(get_activation(self.feature, backbone_idx=int(layer_second))))
  83. else:
  84. self.hooks.append(model.model[int(layer)].register_forward_hook(get_activation(self.feature)))
  85. self.__dict__.update(locals())
  86. def get_input_grad(self, samples):
  87. _ = self.model(samples)
  88. outputs = self.feature[-1]
  89. self.feature.clear()
  90. out_size = outputs.size()
  91. central_point = torch.nn.functional.relu(outputs[:, :, out_size[2] // 2, out_size[3] // 2]).sum()
  92. grad = torch.autograd.grad(central_point, samples)
  93. grad = grad[0]
  94. grad = torch.nn.functional.relu(grad)
  95. aggregated = grad.sum((0, 1))
  96. grad_map = aggregated.cpu().numpy()
  97. return grad_map
  98. def process(self):
  99. for image_path in os.listdir(self.dataset):
  100. if self.meter.count == self.num_images:
  101. break
  102. img = cv2.imread(f'{self.dataset}/{image_path}')
  103. img = letterbox(img, auto=False)[0]
  104. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  105. img = np.float32(img) / 255.0
  106. samples = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)
  107. samples.requires_grad = True
  108. self.optimizer.zero_grad()
  109. contribution_scores = self.get_input_grad(samples)
  110. if np.isnan(np.sum(contribution_scores)):
  111. print('got NAN, next image')
  112. continue
  113. else:
  114. print(f'{self.meter.count}/{self.num_images} calculate....')
  115. self.meter.update(contribution_scores)
  116. # Set figure parameters
  117. large = 24; med = 24; small = 24
  118. params = {'axes.titlesize': large,
  119. 'legend.fontsize': med,
  120. 'figure.figsize': (16, 10),
  121. 'axes.labelsize': med,
  122. 'xtick.labelsize': med,
  123. 'ytick.labelsize': med,
  124. 'figure.titlesize': large}
  125. plt.rcParams.update(params)
  126. plt.style.use('seaborn-whitegrid')
  127. sns.set_style("white")
  128. plt.rc('font', **{'family': 'Times New Roman'})
  129. plt.rcParams['axes.unicode_minus'] = False
  130. data = self.meter.avg
  131. print(f'max value:{np.max(data):.3f} min value:{np.min(data):.3f}')
  132. data = np.log10(data + 1) # the scores differ in magnitude. take the logarithm for better readability
  133. data = data / np.max(data) # rescale to [0,1] for the comparability among models
  134. print('======================= the high-contribution area ratio =====================')
  135. for thresh in [0.2, 0.3, 0.5, 0.99]:
  136. side_length, area_ratio = get_rectangle(data, thresh)
  137. print('thresh, rectangle side length, area ratio: ', thresh, side_length, area_ratio)
  138. heatmap(data, save_path=self.save_path)
  139. def get_params():
  140. params = {
  141. 'weight': 'yolov8n.pt', # 只需要指定权重即可
  142. 'device': 'cuda:0',
  143. 'layer': '10', # string
  144. 'dataset': '',
  145. 'num_images': 50,
  146. 'save_path': 'result.png'
  147. }
  148. return params
  149. if __name__ == '__main__':
  150. cfg = get_params()
  151. yolov8_erf(**cfg).process()