Swin Transformer源码解析(二)
二、Transformer Block
数据经过patch_embed后接着进入TransformerBlock模块,TransformerBlock主要包含四个部分:NormLayer==>W-MSA/SW-MSA==>NormLayer==>MLP,内部各部分还使用残差连接。
1. Norm Layer
NormLayer默认使用LayerNorm,对最后一维归一化,即模型的维度C
2. W-MSA/SW-MSA
窗口自注意力和移位窗口自注意力,将patch的特征图划分成一个个window,然后再在每个window内部做自注意力,但是这样window和window之间无交互,所以又使用了移位窗口自注意力。
2.1 window_partition
类似把图片分成pacth的操作,这里将patch_embedding操作后的特征图按window划分,但不同的是patch_embedding中有个embedding的过程,是通过卷积实现的但是这里不需要,只是简单的分成window。window_partition操作是将图片的形式由(2,56,56,96)==>(2*8*8,7,7,96) 8*8就是window的数量,可以看出维度没有变化,且内部也没有任何神经元的连接。
至于为什么要乘以8*8,是因为后面要在window内部做注意力,window与window之间无关,所以直接乘到batch_size里面。
2.2 window_reverse
和window_partition的操作相反,将划分后的windows转回去,形状一样,对应位置也一样。因为做完自注意力之后要变成之前的形状,因为后面要做patch_merge,要转成patch的格式
2.3 window_attention
3、Norm Layer
4、MLP
很简单,就是全连接==>激活==>dropout==>全连接==>dropout
三、Patch Merge
就是将patch特征图变小,但是维度增加