如何生成有一定可解释性的Attention Map

参考Github Repo: jeonsworld/ViT-pytorch

我们先试着跑一跑仓库里的train.py,本人写了一个脚本方便快速运行,作者也提供了预训练权重(.npz)。这里以ViT-B-16的模型为例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
python train.py \
--dataset cifar10 \
--model_type ViT-B_16 \
--pretrained_dir "checkpoint/imagenet21k+imagenet2012_ViT-B_16.npz" \
--output_dir "./outputs" \
--img_size 224 \
--train_batch_size 256 \
--eval_batch_size 4 \
--name "exp_1" \
--eval_every 100 \
--learning_rate 3e-2 \
--weight_decay 0 \
--num_steps 10000 \
--decay_type cosine \
--warmup_steps 500 \
--max_grad_norm 1.0 \
--local_rank -1 \
--seed 42 \
--gradient_accumulation_steps 1 \
--fp16 \
--fp16_opt_level O2 \
--loss_scale 0

实践证明是能跑的,成功了第一步。

下面就着重研究如何可视化.

首先放出可视化脚本的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import typing
import io
import os

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt

from urllib.request import urlretrieve

from PIL import Image
from torchvision import transforms

from models.modeling import VisionTransformer, CONFIGS

os.makedirs("attention_data", exist_ok=True)
if not os.path.isfile("attention_data/ilsvrc2012_wordnet_lemmas.txt"):
urlretrieve("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt", "attention_data/ilsvrc2012_wordnet_lemmas.txt")
if not os.path.isfile("attention_data/ViT-B_16-224.npz"):
urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz", "attention_data/ViT-B_16-224.npz")

imagenet_labels = dict(enumerate(open('attention_data/ilsvrc2012_wordnet_lemmas.txt')))

# # Test Image
# img_url = "https://images.mypetlife.co.kr/content/uploads/2019/04/09192811/welsh-corgi-1581119_960_720.jpg"
# urlretrieve(img_url, "attention_data/img.jpg")

# Prepare Model
config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=True)
model.load_from(np.load("attention_data/ViT-B_16-224.npz"))
model.eval()

transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
im = Image.open("/data/dataset/ImageNet/train/n01440764/n01440764_10026.JPEG")
x = transform(im)

logits, att_mat = model(x.unsqueeze(0))

att_mat = torch.stack(att_mat).squeeze(1) # torch.stack()进行扩维拼接

# Average the attention weights across all heads.
att_mat = torch.mean(att_mat, dim=1)

# To account for residual connections, we add an identity matrix to the
# attention matrix and re-normalize the weights.
residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

# Recursively multiply the weight matrices
joint_attentions = torch.zeros(aug_att_mat.size())
joint_attentions[0] = aug_att_mat[0]

for n in range(1, aug_att_mat.size(0)):
joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])

# Attention from the output token to the input space.
v = joint_attentions[-1]
grid_size = int(np.sqrt(aug_att_mat.size(-1)))
mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
result = (mask * im).astype("uint8")

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))

ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(im)
_ = ax2.imshow(result)

fig.savefig('output_image.png')

probs = torch.nn.Softmax(dim=-1)(logits)
top5 = torch.argsort(probs, dim=-1, descending=True)
print("Prediction Label and Attention Map!\n")
for idx in top5[0, :5]:
print(f'{probs[0, idx.item()]:.5f} : {imagenet_labels[idx.item()]}', end='')

生成的图片长这样:
output_image

我们详细看看可视化代码里干了什么事情。

前期的网络处理流程如下图所示。
ViT_Base_16 Net Structure

核心部分处理score matrix 和Attention map的代码如下:

1
logits, att_mat = model(x.unsqueeze(0))

此行代码预测并获取注意力矩阵, 此时若运行print(len(att_mat)), 我们会得到输出为12,代表模型输出的Attention Map共有12个。这个数字与Transformer Encoder Block的个数是一致的。即Transformer强制要求输入输出为同一维度,这样才可以级联好多个Transformer块进行不断地网络学习与拟合。

1
att_mat = torch.stack(att_mat).squeeze(1)

这是在对attention map list进行拼接,降维处理。这里给出torch.stack API的用法:

1
2
3
4
5
6
7
8
9
10
torch.stack(tensors, dim=0, *, out=None) → Tensor
Concatenates a sequence of tensors along a new dimension.
All tensors need to be of the same size.
SEE ALSO
torch.cat(): concatenates the given sequence along an existing dimension.
* Parameters:
tensors (sequence of Tensors) – sequence of tensors to concatenate
dim (int, optional) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated tensors (inclusive). Default: 0
* Keyword Arguments
out (Tensor, optional) – the output tensor.

不传入dim参数时默认为0,故将这个list以第1个维度进行拼接。之后再进行squeeze(1)操作,得到的是除去指定位置的维数为1的维度输出(这里比较绕QAQ)。运行print(att_mat.shape),得到的输出是torch.Size([12, 12, 197, 197])

我猜测的数据流是这样的:
Stack and Squeeze

1
att_mat = torch.mean(att_mat, dim=1)

按照第二个维度(12)进行均值计算,最后得到的理论上是12张一样的平均过后的attention map。

1
2
3
residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

(以下来源:ChatGPT)
在 ViT 中,注意力矩阵 AA 表示每个token对其他token的注意力权重。通过添加单位矩阵 II

augattmat=A+Iaug_{att_{mat}}=A+I

其中,II 是一个单位矩阵(对角线上元素为1,其他元素为0),我们确保了每个token对自己的注意力权重至少为1。这种方式可以看作是残差连接的一种形式,它保证了每个token的原始信息在经过多个注意力层后仍能保留

添加单位矩阵后,需要重新归一化注意力矩阵,使得每一行的和为1,这是因为在计算注意力时,我们通常希望每个token的注意力权重之和为1,以保持概率分布的性质。

{aug_{att_{mat}}}=\frac{aug_{att_{mat}}}{\sum\(aug_{att_{mat}},dim=-1,keepdim=True\)}

re norm

通过这两步操作,我们得到了一个增强的注意力矩阵(augmented attention matrix),它不仅考虑了每个token对其他token的注意力权重,还保留了每个token对自身的注意力权重,从而能够更好地捕捉不同层之间的信息流和依赖关系。

1
2
3
4
5
joint_attentions = torch.zeros(aug_att_mat.size()) # 初始化联合注意力矩阵
joint_attentions[0] = aug_att_mat[0] # 设置第一层的联合注意力矩阵

for n in range(1, aug_att_mat.size(0)): # 递归相乘
joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])

这两行进行递归累乘操作。首先初始化了一个与增强注意力矩阵(aug_att_mat)尺寸相同的零矩阵,用于存储每层的联合注意力矩阵。在第一层时,联合注意力矩阵就是增强注意力矩阵本身。从第二层开始(n=1),通过递归地相乘当前层的增强注意力矩阵和前一层的联合注意力矩阵,来更新联合注意力矩阵。具体来说,当前层的联合注意力矩阵等于当前层的增强注意力矩阵乘以前一层的联合注意力矩阵。

为什么要递归相乘?递归相乘的目的是为了累积每层注意力矩阵的影响。通过这种方式,最终得到的联合注意力矩阵能够综合考虑所有层次的注意力关系,捕捉输入token之间更复杂的依赖关系。

用图表示joint_attn计算过程如下:
Joint Attention Calculation

1
v = joint_attentions[-1]

取最后一个作为整个模型的联合注意力矩阵.

1
2
3
4
grid_size = int(np.sqrt(aug_att_mat.size(-1)))
mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
result = (mask * im).astype("uint8")

先计算每个Patch的边长,这里为14,再重塑注意力矩阵,提取最后一层的联合Attention Map中第一个token[CLS]对所有token的注意力权重, 将这个一维数组reshape为二维数组,形成一个网格,每个值表示对应位置的注意力权重。通过cv2.resize函数将注意力掩码调整到与原图像相同的大小。mask / mask.max() 将注意力掩码归一化到0-1之间。[…, np.newaxis] 为了使mask具有三个通道,这样可以与原图像进行逐像素乘法。

计算过程如下图:
reshape and resize

接着我又改了一个代码块出来,能输出12个Block的特征图输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
attention_maps = []
for i in range(0, aug_att_mat.shape[0]):
tmp_attn_map = aug_att_mat[i]
grid_size = int(np.sqrt(aug_att_mat.size(-1)))
mask = tmp_attn_map[0, 1:].reshape(grid_size, grid_size).detach().numpy()
# print(mask.shape)
mask = cv2.resize(mask / mask.max(), im.size)
mask = (mask * 255).astype(np.uint8)
heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
output_filename = f'attention_mask_{i}.png'
attention_maps.append(heatmap)
# cv2.imwrite(output_filename, heatmap)

map_height, map_width, _ = attention_maps[0].shape
num_maps = len(attention_maps)
canvas_height = map_height
canvas_width = map_width * num_maps

canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)

for i, heatmap in enumerate(attention_maps):
canvas[:, i * map_width:(i + 1) * map_width, :] = heatmap

cv2.imwrite('attention_maps_canvas.png', canvas)

输出的12张图并起来长这样:
attention maps
确实证明了ViT存在Attention Map中有Artifacts存在。

SNN-based Transformer有没有类似的现象呢

要想得到类似ANN的Attention Map,我们需要知道在传统ViT中Attention Score矩阵是怎么算的。

ViT中通过输入得到的 QQ , KK , VV 计算得到attention score,公式如下:

score=QKTdscore = \frac{QK^T}{\sqrt{d}}

dd 为MHA的头的个数,再经过softmax得到有概率意义的attention map:

map=Softmax(Score)map = \textbf{Softmax}(Score)

之后的处理就是同上面的步骤,最后得到可视化的attention map。

在Spike-driven Transformer中,针对脉冲编码性质的输入而言,计算attention score的方式略有差别:以V2(Metaspikeformer)为例,该工作选用作者提出的四种SDSA(Spike-driven Self Attention)中的第三种(SDSA-3)

V^S=SDSA3(Qs,Ks,Vs)=SN(Qs(KsTVs))=SN((QsKs)TVs)\hat{V}_S = SDSA_3(Q_s,K_s,V_s) = SN(Q_s(K_s^TV_s)) = SN((Q_sK_s)^TV_s)

写到这里,我又翻出来了V1的文章,才发现作者们在附录里是有写怎么定义Attention Map的(笑死)

作者在附录中指出,SDSA计算是一种Hard Attention,可以被理解成是在脉冲形式的Value上掩盖不重要的通道(channels)。

关于Hard Attention(硬性注意力)和Soft Attention(软性注意力)的区别与联系参考漫谈注意力机制(二):硬性注意力机制与软性注意力机制

VSV_S同样有时间(Time,T)和头数(Head,H)两个维度,在这两个维度上做平均,输出是每个LIF的脉冲发放率。换言之,Attention Score调控脉冲发放频率。

在Spike-driven Transformer V1中,输入x(torch.size([T,B,C,H,W]))需要进行flatten(0,1),即把第一,第二维度展平,成为一个维度