欢迎光临散文网 会员登陆 & 注册

Swin Transformer源码解析(二)

2023-07-03 15:15 作者:0x435959  | 我要投稿

二、Transformer Block

数据经过patch_embed后接着进入TransformerBlock模块,TransformerBlock主要包含四个部分:NormLayer==>W-MSA/SW-MSA==>NormLayer==>MLP,内部各部分还使用残差连接。

1. Norm Layer

NormLayer默认使用LayerNorm,对最后一维归一化,即模型的维度C

2. W-MSA/SW-MSA

窗口自注意力和移位窗口自注意力,将patch的特征图划分成一个个window,然后再在每个window内部做自注意力,但是这样window和window之间无交互,所以又使用了移位窗口自注意力。

2.1 window_partition

类似把图片分成pacth的操作,这里将patch_embedding操作后的特征图按window划分,但不同的是patch_embedding中有个embedding的过程,是通过卷积实现的但是这里不需要,只是简单的分成window。window_partition操作是将图片的形式由(2,56,56,96)==>(2*8*8,7,7,96) 8*8就是window的数量,可以看出维度没有变化,且内部也没有任何神经元的连接。

至于为什么要乘以8*8,是因为后面要在window内部做注意力,window与window之间无关,所以直接乘到batch_size里面。

2.2 window_reverse

和window_partition的操作相反,将划分后的windows转回去,形状一样,对应位置也一样。因为做完自注意力之后要变成之前的形状,因为后面要做patch_merge,要转成patch的格式

2.3 window_attention

3、Norm Layer

4、MLP

很简单,就是全连接==>激活==>dropout==>全连接==>dropout

三、Patch Merge

就是将patch特征图变小,但是维度增加






Swin Transformer源码解析(二)的评论 (共 条)

分享到微博请遵守国家法律