68 Transformer【动手学深度学习v2】

This is a note meant to help readers understand the multi-head attention code.
- Code reference: https://blog.csdn.net/qq_44833392/article/details/122247313
- Recommended article (from Baidu PaddlePaddle, very detailed): https://paddlepedia.readthedocs.io/en/latest/tutorials/pretrain_model/transformer.html
```python
import numpy as np
import torch
from torch import nn
from torch.nn.functional import softmax


class MultiHead(nn.Module):
    def __init__(self, n_head, model_dim, drop_rate):
        # n_head:    number of attention heads
        # model_dim: model (embedding) dimension
        # drop_rate: dropout probability
        super().__init__()
        self.n_head = n_head
        self.head_dim = model_dim // n_head  # e.g. 32 // 4 = 8
        self.wq = nn.Linear(model_dim, n_head * self.head_dim)  # 4 * 8 = 32
        self.wk = nn.Linear(model_dim, n_head * self.head_dim)
        self.wv = nn.Linear(model_dim, n_head * self.head_dim)
        self.o_dense = nn.Linear(model_dim, model_dim)
        self.o_drop = nn.Dropout(drop_rate)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, q, k, v, mask, training):
        # q = k = v = [batch_size, seq_len, emb_dim] = [32, 11, 32]
        # `training` is kept from the original signature but is unused here;
        # dropout follows the module's train/eval mode instead.
        residual = q  # saved for the residual connection

        # linear projections
        key = self.wk(k)    # [batch_size, seq_len, n_head * head_dim]
        value = self.wv(v)  # [batch_size, seq_len, n_head * head_dim]
        query = self.wq(q)  # [batch_size, seq_len, n_head * head_dim]

        # split out the heads:
        # [batch_size, n_head, seq_len, head_dim]
        query = self.split_heads(query)  # [32, 4, 11, 8]
        key = self.split_heads(key)      # [32, 4, 11, 8]
        value = self.split_heads(value)  # [32, 4, 11, 8]

        # scaled dot-product attention
        context = self.scaled_dot_product_attention(
            query, key, value, mask)  # [batch_size, seq_len, model_dim]

        # output projection
        o = self.o_dense(context)  # [batch_size, seq_len, model_dim]
        # randomly zero out some activations
        o = self.o_drop(o)
        # residual connection + layer normalization
        o = self.layer_norm(residual + o)
        return o

    def split_heads(self, x):
        x = torch.reshape(
            x, (x.shape[0], x.shape[1], self.n_head, self.head_dim))
        # x = [batch_size, seq_len, n_head, head_dim]
        return x.permute(0, 2, 1, 3)  # [batch_size, n_head, seq_len, head_dim]

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        # dk = head_dim = 8
        dk = torch.tensor(k.shape[-1]).type(torch.float)
        score = torch.matmul(q, k.permute(0, 1, 3, 2)) / (
            torch.sqrt(dk) + 1e-8)  # [batch_size, n_head, seq_len, seq_len] = [32, 4, 11, 11]
        if mask is not None:
            # masked positions get -inf so their softmax weight becomes 0
            score = score.masked_fill_(mask, -np.inf)
        self.attention = softmax(score, dim=-1)    # [32, 4, 11, 11]
        context = torch.matmul(self.attention, v)  # [batch_size, n_head, seq_len, head_dim]
        context = context.permute(0, 2, 1, 3)      # [batch_size, seq_len, n_head, head_dim]
        context = context.reshape(
            (context.shape[0], context.shape[1], -1))  # merge the heads
        return context  # [batch_size, seq_len, model_dim]
```
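As a sanity check on the shapes in the comments above, here is a minimal usage sketch. The values batch_size=32, seq_len=11, model_dim=32, n_head=4 come from the shape comments; the all-False mask, drop_rate=0.1, and training=True argument are illustrative assumptions, not part of the original note. Each head computes softmax(QKᵀ/√dk)·V, and the per-head outputs are concatenated back to model_dim.

```python
import torch

batch_size, seq_len, model_dim, n_head = 32, 11, 32, 4

mha = MultiHead(n_head=n_head, model_dim=model_dim, drop_rate=0.1)

x = torch.randn(batch_size, seq_len, model_dim)  # q = k = v for self-attention
# False everywhere = no position is masked; the mask broadcasts over the head dimension
mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.bool)

out = mha(x, x, x, mask, training=True)
print(out.shape)            # torch.Size([32, 11, 32])
print(mha.attention.shape)  # torch.Size([32, 4, 11, 11])
```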