神经网络已经在很多场景下表现出了很好的识别能力,但是缺乏解释性一直所为人诟病。《Grad-CAM:Visual Explanations from Deep Networks via Gradient-based Localization》这篇论文基于梯度为其可解释性做了一些工作,它可以显著描述哪块图片区域对识别起了至关重要的作用,以热度图的方式可视化神经网络的注意力。本博客主要是基于pytorch的简单工程复现。原文见这里,本代码基于这里

  1 import torch
2 import torchvision
3 from torchvision import models
4 from torchvision import transforms
5 from PIL import Image
6 import pylab as plt
7 import numpy as np
8 import cv2
9
10
11 class Extractor():
12 """
13 pytorch在设计时,中间层的梯度完成回传后就释放了
14 这里用hook工具在保存中间参数的梯度
15 """
16 def __init__(self, model, target_layer):
17 self.model = model
18 self.target_layer = target_layer
19 self.gradient = None
20
21 def save_gradient(self, grad):
22 self.gradient=grad
23
24 def __call__(self, x):
25 outputs = []
26 self.gradients = []
27 for name,module in self.model.features._modules.items():
28 x = module(x)
29 if name == self.target_layer:
30 x.register_hook(self.save_gradient)
31 target_activation=x
32 x=x.view(1,-1)
33 for name,module in self.model.classifier._modules.items():
34 x = module(x)
35 # 维度为(1,c, h, w) , (1,class_num)
36 return target_activation, x
37
38
39 def preprocess_image(path):
40 means=[0.485, 0.456, 0.406]
41 stds=[0.229, 0.224, 0.225]
42 m_transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(means,stds)])
43 img=Image.open(path)
44 return m_transform(img).reshape(1,3,224,224)
45
46
47 class GradCam():
48 def __init__(self, model, target_layer_name, use_cuda):
49 self.model = model
50 self.model.eval()
51 self.cuda = use_cuda
52 if self.cuda:
53 self.model = model.cuda()
54
55 self.extractor = Extractor(self.model, target_layer_name)
56
57
58 def __call__(self, input, index = None):
59 if self.cuda:
60 target_activation, output = self.extractor(input.cuda())
61 else:
62 target_activation, output = self.extractor(input)
63
64 # index是想要查看的类别,未指定时选择网络做出的预测类
65 if index == None:
66 index = np.argmax(output.cpu().data.numpy())
67
68 # batch维为1(我们默认输入的是单张图)
69 one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32)
70 one_hot[0][index] = 1.0
71 one_hot = torch.tensor(one_hot)
72 if self.cuda:
73 one_hot = torch.sum(one_hot.cuda() * output)
74 else:
75 one_hot = torch.sum(one_hot * output)
76
77 self.model.zero_grad()
78 one_hot.backward(retain_graph=True)
79
80 grads_val = self.extractor.gradient.cpu().data.numpy()
81 # 维度为(c, h, w)
82 target = target_activation.cpu().data.numpy()[0]
83 # 维度为(c,)
84 weights = np.mean(grads_val, axis = (2, 3))[0, :]
85 # cam要与target一样大
86 cam = np.zeros(target.shape[1 : ], dtype = np.float32)
87 for i, w in enumerate(weights):
88 cam += w * target[i, :, :]
89
90 # 每个位置选择c个通道上最大的最为输出
91 cam = np.maximum(cam, 0)
92 cam = cv2.resize(cam, (224, 224))
93 cam = cam - np.min(cam)
94 cam = cam / np.max(cam)
95 return cam
96
97
98 def show_cam_on_image(img, mask):
99 heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
100 heatmap = np.float32(heatmap) / 255
101 cam = heatmap + np.float32(img)
102 cam = cam / np.max(cam)
103 cv2.imwrite("cam2.jpg", np.uint8(255 * cam))
104
105
106 #target_layer 越靠近分类层效果越好
107 grad_cam = GradCam(model = models.vgg19(pretrained=True), target_layer_name = "35", use_cuda=True)
108 input = preprocess_image("both.png")
109 mask = grad_cam(input, None)
110 img = cv2.imread("both.png", 1)
111 #热度图是直接resize加到输入图上的
112 img = np.float32(cv2.resize(img, (224, 224))) / 255
113 show_cam_on_image(img, mask)

原图:

可视化图:

最新文章

  1. 领域驱动设计系列 (六):CQRS
  2. 介绍MFSideMenu左右滑动控件的使用
  3. 国内技术管理人员批阅google的“春运交通图”项目(大公司下的高效率)<转载>
  4. 关于HttpHandler的相关知识总结
  5. JavaScript 入门 (1)
  6. host DNS 访问规则
  7. codeforces George and Job
  8. [iOS]如何删除工程里面用cocoapods导入的第三方库
  9. 关于div的居中的问题
  10. solr使用方法 完全匹配
  11. 1008: University
  12. Find and run the whalesay image
  13. JS 从一个字符串中截取两个字符串之间的字符串
  14. 10_Eclipse中演示Git冲突的解决
  15. 我的csdn博客搬家了
  16. mongodb的TTL索引介绍(超时索引)
  17. 使用责任链模式消除if分支实践
  18. navicat导入sql文件错误
  19. Linux一键安装宝塔控制面板
  20. gogs 安装

热门文章

  1. Android四大组件——Activity——Activity数据回传
  2. IO ——字节流
  3. python基础练习题(题目 三数排序。)
  4. Python中的Super详解
  5. Blazor 组件库 BootstrapBlazor 中Editor组件介绍
  6. Unity—TextMeshPro
  7. javascript生成一棵树
  8. 一文讲透APaaS平台是什么
  9. fedora使用root超级用户
  10. JuiceFS v1.0 beta3 发布,支持 etcd、Amazon MemoryDB、Redis Cluster