欢迎光临散文网 会员登陆 & 注册

68 Transformer【动手学深度学习v2】

2023-03-07 13:34 作者:小歧鹿  | 我要投稿

这是一篇方便大家理解多头注意力的代码是的笔记

  • 代码参考:https://blog.csdn.net/qq_44833392/article/details/122247313
  • 推荐文章(百度飞桨,讲的很细致):https://paddlepedia.readthedocs.io/en/latest/tutorials/pretrain_model/transformer.html
class MultiHead(nn.Module):
    def __init__(self, n_head, model_dim, drop_rate):
        # n_head 有几层注意力机制
        # model_dim 模型的维度
        # drop_rate 随机丢弃率
        super().__init__()
        self.head_dim = model_dim // n_head     # 32//4=8
        self.wq = nn.Linear(model_dim, n_head * self.head_dim)  # [4*8]
        self.wk = nn.Linear(model_dim, n_head * self.head_dim)
        self.wv = nn.Linear(model_dim, n_head * self.head_dim)

        self.o_dense = nn.Linear(model_dim, model_dim)
        self.o_drop = nn.Dropout(drop_rate)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, q, k, v, mask, training):
        # residual connect
        # q=k=v=[batch_size,seq_len, emb_dim]=[32,11,32]
        residual = q    # 残差

        # linear projection
        key = self.wk(k)    # [batch_size,seq_len, num_heads * head_dim]
        value = self.wv(v)  # [batch_size,seq_len, num_heads * head_dim]
        query = self.wq(q)  # [batch_size,seq_len, num_heads * head_dim]

        # 将头分离出来
        # [step,n_head,n,head_dim] = [batch_size,头的数量,seq_len,每个头的维度]
        query = self.split_heads(query) # [32,4,11,8]
        key = self.split_heads(key)     # [32,4,11,8]
        value = self.split_heads(value) # [32,4,11,8]

        # 自注意力机制 点乘 
        context = self.scaled_dot_product_attention(
            query, key, value, mask)    # [batch_size,seq_len, model_dim]

        # 再经过一个线性变化
        o = self.o_dense(context)       # [batch_size,seq_len, model_dim]
        # 随机使得一些权重失效
        o = self.o_drop(o)
        # layer normalization
        o = self.layer_norm(residual+o)
        return o

    def split_heads(self, x):
        x = torch.reshape(
            x, (x.shape[0], x.shape[1], self.n_head, self.head_dim))
        # x = [step,n_head,n,head_dim]
        return x.permute(0, 2, 1, 3)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        # [32,4,11,11]
        # dk = 8
        dk = torch.tensor(k.shape[-1]).type(torch.float)
        score = torch.matmul(q, k.permute(0, 1, 3, 2)) / (torch.sqrt(dk) + 1e-8)                 # [step, n_head, n, n]=[32, 4, 11, 11]
        if mask is not None:
            score = score.masked_fill_(mask, -np.inf)
        self.attention = softmax(score, dim=-1)     # [32, 4, 11, 11]
        context = torch.matmul(self.attention, v)   # [step, num_head, n, head_dim]
        context = context.permute(0, 2, 1, 3)       # [batch_size,seq_len, num_head, head_dim]
        context = context.reshape((context.shape[0], context.shape[1], -1))
        return context                              # [batch_size,seq_len, model_dim]


68 Transformer【动手学深度学习v2】的评论 (共 条)

分享到微博请遵守国家法律