欢迎光临散文网 会员登陆 & 注册

从零实现BERT、GPT及Difussion类算法-3:Multi-head Attention & Transformer

2023-04-24 00:37 作者:一代闲人  | 我要投稿

教程简介及目录见: 从零实现BERT、GPT及Difussion类算法:文章简介及目录

本章完整源码见https://github.com/firechecking/CleanTransformer/blob/main/CleanTransformer/transformer.py


这一章将参考《attention is all you need》论文,实现Multi-head Attention、LayerNorm、TransformerBlock,有了这章的基础后,在下一章就可以开始搭建Bert、GPT等模型结构了


Multi-head Attention

参考:https://arxiv.org/abs/1706.03762

Attention介绍(选读)

  • 先简单介绍下一般性的Attention,如果已经了解的同学可以跳过

  • Attention字面意思是注意力,也就是让模型能够学习到一个权重,来将输入选择性的传入下一层

  • 比较常用的操作如下:

  • 首先假定输入tensor为q, k, v,其中Size_q%20%3D%20Size_k%3DSize_v%20%3D%20(BatchSize%2C%20SeqLen%2C%20Dim)

  • self-attention是attention的一个特例:q=k=v

  • 以下给出基础attention的伪代码

Multi-head Attention基础原理

由以上论文截图可知,Size_q%20%3D%20Size_k%3DSize_v%20%3D%20(BatchSize%2C%20SeqLen%2C%20NHead*HeadDim)

所以实现步骤如下(可以和上文基础Attention对比着看):

  1. 对Q,K,V进行Linear:得到新的Q、K、V的size不变

  2. Multi-Head拆分:Size_q%20%3D%20Size_k%3DSize_v%20%3D%20(BatchSize%2C%20NHead%2CSeqLen%2C%20%20HeadDim)

  3. 使用Q、K计算Weight(其中第二行是Transformer在attention基础上增加的一个scaling factor)

    W_%7Bb%2C%20h%2Ci%2Cj%7D%20%3D%20%5Csum_tQ_%7Bb%2C%20h%2Ci%2Ct%7DK_%7Bb%2C%20h%2Ct%2Cj%7D%20%5C%5C%0AW%20%3D%20W%2Fsqrt(HeadDim)%20%5C%5C%0AW%20%3D%20Softmax(W)%20%5C%5C%0AW%20%3D%20Dropout(W)%20%5C%5C%0ASize_w%20%3D%20(BatchSize%2C%20NHead%2C%20SeqLen%2C%20SeqLen)%0A

  4. 使用Weight和V,计算新的V

    V_%7Bb%2Ch%2Ci%2Cj%7D%20%3D%20%5Csum_tW_%7Bb%2Ch%2Ci%2Ct%7DV_%7Bb%2Ch%2Ct%2Cj%7D%20%5C%5C%0ASize_v%20%3D%20(BatchSize%2C%20NHead%2C%20SeqLen%2C%20HeadDim)

  5. 对V进行维度变换

    Size_v%20%3D%20Size_%7Bv.transpose(1%2C2)%7D%20%3D%20(BatchSize%2C%20SeqLen%2C%20Nhead%2C%20HeadDim)%20%5C%5C%0ASize_v%20%3D%20Size_%7Bv.view()%7D%20%3D%20(BatchSize%2C%20SeqLen%2C%20NHead*HeadDim)

Multi-head Attention实现代码

  • 代码不是太复杂,结合上文和注释,应该能很容易看懂


LayerNorm

参考

  • https://arxiv.org/abs/1607.06450

  • https://arxiv.org/abs/1607.06450

  • https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html

  • https://blog.csdn.net/xinjieyuan/article/details/109587913

BatchNorm与LayerNorm的差异

  • batch normalization

  • 对每一个输入,先在mini-batch上计算输入的均值、标准差,然后当前层的每个输入值使用均值、标准差进行正则计算

  • 公式如下

  • 先在mini-batch上计算出每个位置的均值%5Cmu_i%3D%5Cfrac%7B1%7D%7BM%7D%5Csum%5EM_%7Bi%3D1%7Da_i、标准差%5Csigma_i%20%3D%20%5Csqrt%7B%5Cfrac%7B1%7D%7BM%7D%5Csum%5EM_%7Bi%3D1%7D(a_i-%5Cmu_i)%5E2%7D,其中M为mini-batch大小

  • 然后对每个值应用变换%5Cbar%7Ba%7D_i%20%3D%20%5Cfrac%7Ba_i-%5Cmu_i%7D%7B%5Csigma_i%7D

    • 提示:这里之所以有下标i,是因为batch normalization是在batch内,对不同样本的相同位置进行归一

  • layer normalization

    • batch normalization是在batch内,对不同样本相同位置进行归一;而layer normalization是在layer内,对同一个样本不同位置进行归一

    • batch normalization不在整个mini-batch上计算均值、标准差,而是在当前层的当前样本计算输入的均值、标准差,然后对当前层的当前样本输入值使用均值、标准差进行正则计算(也可以理解为Layer Normalization是和batch无关,而是对每个样本单独处理)

    • 公式如下

    • 先在单个样本上计算出每一层的均值%5Cmu%5El%3D%5Cfrac%7B1%7D%7BH%7D%5Csum%5EH_%7Bi%3D1%7Da%5El_i、标准差%5Csigma%5El_i%20%3D%20%5Csqrt%7B%5Cfrac%7B1%7D%7BH%7D%5Csum%5EH_%7Bi%3D1%7D(a%5El_i-%5Cmu%5El)%5E2%7D,其中H为当前layer的大小hidden units数量

    • 然后对每个值应用变换%5Cbar%7Ba%7D%5El_i%20%3D%20%5Cfrac%7Ba%5El_i-%5Cmu%5El%7D%7B%5Csigma%5El%7D


    Layer Normalization代码实现

    • 代码如下

    • eps为一个较小值,是为了防止标准差std为0时,0作为除数

    • 从上文公式看出,标准差是计算(a%5El_i-%5Cmu%5El)%5E2的均值后开根号,所以代码中有std = self._mean((x - mean).pow(2) + self.eps).pow(0.5),是复用了self._mean()的计算均值操作

    • 为了和pytorch的LayerNorm保持一致,这里同样可以接受normalized_shape参数,表示Normalization要在哪几个维度上进行计算,可以结合_mean()函数中的代码进行理解

    TransformerBlock

    参考

    • https://arxiv.org/abs/1706.03762

    Transformer原理

    • 从《Attention Is All You Need》论文中这张图可以看出以下几点信息:

    • Encoder、Decoder基本相同,最大差别是Decoder上多了一层Multi-Head Attention

    • 每一个TransformerBlock只由Multi-Head AttentionAddLayerNormLinear这4种操作组合而成

    • 在上文已经实现的Multi-Head Attention、LayerNorm基础上,再来实现TransformerBlock就很简单了

    • 为进一步简化,在本章我们先只实现Encoder,并且省略掉mask等额外操作。到之后讲到GPT时再来实现Decoder以及更完善的TransformerBlock

    Transformer代码实现

    • 代码主要由attention+Add+Norm,以及FFW+Add+Norm这2个部分组成,其中ffw是两层全连接中间夹一个ReLU激活函数

    • 从以上代码看出TransformerBlock还是非常简洁的,而Bert、GPT等模型就是对TransformerBlock的堆叠,这部分内容将放在下一章讲解


    从零实现BERT、GPT及Difussion类算法-3:Multi-head Attention & Transformer的评论 (共 条)

    分享到微博请遵守国家法律