Swin Transformer论文精读【论文精读】

swin transformer代码库
一系列更新,swin transformer 3月份上传论文,4月份代码库出来
紧接着5月12号又放出来了自监督版本的Swin Transformer--moby,其实就是把MoCo的前两个字母和 BYOL 的前两个字母合在了一起,从方法上和性能上其实和MoCo v3和DINO都差不多,只是换了个骨干网络
接下来过了一个月,Swin Transformer就被用到了视频领域,推出了Video-Swin-Transformer,在一系列数据集上都取得了非常好的效果
- 比如说在 k-400这个数据集上就已经达到了84.9的准确度
7月初的时候,因为看到了有 MLP Mixer 这篇论文,把 Swin 的思想用到了 MLP 里,推出了 Swin MLP
8月初的时候,把 Swin Transformer 用到了半监督的目标检测里,然后取得了非常好的效果
10月份的时候获得了ICCV 的最佳论文奖
12月份受到了 BEiT 和 MAE 的推动,用 Swin Transformer 基于掩码自监督学习的方式做了一个叫 SimMIM 的论文
效果炸裂
Swin Transformer 的提出主要是用来做视觉的下游任务,所以主要看一下 COCO 和 ADE20K这两个数据集上的表现
下图 COCO 数据集上的表现

ADE20K数据集

研究动机:证明transformer是一个通用骨干网络,可以用于所有视觉任务

vit缺陷:虽然可以通过全局自注意力操作达到全局建模能力,但是对多尺寸特征的把握会弱一些,不适合处理密集预测任务,全局自注意力对于视觉任务有点浪费资源
检测和分割任务处理多尺寸特征的方法

降低复杂度:小窗口之内算自注意力
如何生成多尺寸特征?
CNN有pooling操作,可以增大卷积核的感受野,从而使得每次池化的特征抓住不同物体的特征
swin transformer提出类似于池化的patch merging,这样合并的大patch内容可以看到之前4个小patch看到的内容
swin的一个关键因素:滑动窗口

模型前向过程

linear embedding:把向量维度变成一个预先设置好的值(Transformer能够接收),论文里把超参数设置为C(网络总览图C=96)
56 x 56=3136 拉直成序列长度,96是每一个token的向量维度
vit的patch size=16 x 16,序列长度=196,3136太长了,不是TRM可以接受的
所以swin transformer引入基于窗口的自注意力,每个窗口按照默认来说有7 x 7=49个patch,序列长度=49非常小,解决了计算复杂度的问题,暂时把transformer block当做一个黑盒,我们只关注输入和输出的维度
想要构建多尺寸信息,需要层级式的transformer block,也就是CNN中的池化操作
patch merging操作

顾名思义,把邻近的小patch合并成一个大patch,就可以起到下采样特征图的效果
这里我们要下采样2倍,所以我们选点时每隔一个点选一个
假如说原来的张量是HxWxC,那么经过这次采样之后得到4个张量,每个张量大小(H/2)x(W/2),尺寸缩小1倍,将张量在C的维度上拼接起来,相当于用空间上的维度换了更多通道数
为了和CNN保持一致(resnet和vggnet一般在池化操作之后,通道数翻2倍),用1 x 1卷积把通道数4C变成2C,空间大小减半,通道数x2,就和CNN完全对等起来
基于窗口(移动窗口)的自注意力
全局自注意力:会导致平方倍的复杂度,(对于视觉的下游任务,尤其是密集型的任务,或者遇到非常大尺寸的图片,全局计算自注意力的复杂度就非常贵)
窗口自注意力:

每一个橘黄色的方格是一个窗口(不是最小计算单元),最小计算单元是patch,每一个窗口里有M x M个patch(论文里M=7),所有的自注意力计算都是在小窗口里完成的(序列长度永远=7x7=49),原来大的整体特征图会有多少窗口?8 x 8=64
我们会在64个窗口里分别计算自注意力
基于窗口自注意力的计算复杂度如何?

(1)标准的多头自注意力的计算复杂度,h=w=56
(2)基于窗口的自注意力,M=7(一个窗口某条边上有多少patch)
公式推算(1)

公式推算(2)(直接套用公式1)
现在高度和宽度不再是hxw了,而是MxM,将M值带入公式1

两公式差别,56 x 56 和7 x 7相差巨大,窗口自注意力很好的解决了计算量问题
新问题:窗口和窗口之间没有通信了,这样达不到全局建模了,会限制模型的能力,我们希望窗口和窗口之间通信起来
作者提出移动窗口的方式

transformer block的安排是有讲究的,每一次先做一个基于窗口的自注意力,然后再做一个基于移动窗口的自注意力,这样就打到窗口和窗口之间的通信

如图3b

两个swin transformer blocks加起来才是ST的一个基本单元,这就是为什么2、2、6、2都是偶数
目前的移动窗口有哪些问题?为甚作者要提高计算性能?

原来算的时候特征图上只有4个窗口,做完移动窗口后得到9个窗口,窗口数量增加,且大小不一
怎么让移位完的窗口数量还保持4个,每个窗口中patch数量保持一致?

作者提出掩码方式,当我们得到移位后的9个窗口之后,我们不在9个窗口上算自注意力,我们再做一次循环移位,得到的窗口数还是4个,计算复杂度固定了
新问题
掩码操作怎么做?

左边:经过循环位移后的操作

swin transformer的几个变体
实验:目标检测结果