Transformer在Masked Self-attention中做的什么?(实现细节)
Transformer是一个训练与预测相互独立的模型,训练和预测的不同主要反应在masked self-attention模块的代码上,经过几个小时的研究终于搞懂,下面对该部分的实现细节记录。需要注意的是接下来提到的全部代码并非来自原始transformer项目,因此可能并不具有普适性,仅作为一种可行的思路介绍。
1.transformer中的训练与预测策略
在训练阶段一般采用teacher forcing的策略,所谓teacher forcing,就是直接将ground truth作为decoder的input,通过masked self-attention计算各时间步的context feature,然后将context与encoder_out共同送入cross-attention中进行跨模态建模。其中在Mashed SA内部使用mask策略避免计算context的过程看到待预测结果(即避免看到gt中当前时间步之后的单词)。
在预测阶段常用的有两种策略,一种被称为beam_search,另一种忘记了名字,预测策略不是本文的重点,因此不对这方面做更多叙述,仅介绍不使用任何策略的最普通的预测方式。做预测时,需要从无到有进行语言的生成,即需要设置若干个时间步,每个时间步中将已预测出的结果作为输入,以此为指导预测下一个单词,这个过程与RNN非常相似,因此transformer在预测时无法体现如训练阶段一般的高并行优势。
2.直觉与实现的区别
从直觉上看,预测与训练并无太大不同,不过要增加一个有关时间步的训练而已,而在实现中却并非如此,以时间步step=3来举例:
2.1直觉
当step=3时,已经完成了对前面3个单词的预测,计划预测第4个单词,因此需要将前3个单词作为decoder的input,假设batch_size=10,那么从直觉上讲送入decoder中input形状应该为(10,3),其中3表示已知单词的索引。后续在decoder中将通过word_embedding方法将单词索引转化为dim=512的向量,经过word_embedding处理后即可送入decoder_layers,其形状为(10,3,512).
在第1个decoder_layer中,首先使用masked self-attention进行处理,将形状为(10,3,512)的input作为query、key、value,经过运算后得到(10,3,512),记为context;接下来,使用cross-attention进行处理,context作为query,encoder_out作为key和value,经过运算后得到形状为(10,3,512)的张量。
在第2个decoder_layer中,进行同样的操作;第3个decoder_layer中,进行同样的操作。最终得到形状为(10,3,512)的张量,取(:,2,:)过Linear(512,voc_len),而后做softmax归一化,作为预测结果的概率。
2.2实现
当阅读代码时,却发现在实现细节方面并非如上面叙述的一般。
其中第1行暂时忽略,稍后解释。第2行和第3行显然是根据计划生成的最大序列长度建立for循环,在每个循环中进行一次预测。进行预测时调用了self.iter方法,在iter方法中真正做预测的代码如下
此处的self.model表示transformer类的对象,self.model.step是定义在transformer中的一个方法,实际上是对transformer.decoder进行调用,在self.model.step的最后一行代码为
显然,在调用transformer.decoder的传入参数中,it即为decoder的input,self.enc_output为编码器的输出。依照直觉,此处的it的形状应该随着时间步的变化而变化,例如当step=1时,由于已经预测出的单词为一个,故it.shape=(bs,1);当step=3时,由于已经预测出了三个单词,故it.shape=(bs,3).
然而,经过对现有代码的调试发现it的形状始终为(bs,1),不随时间步step的增大而变化,这与直觉不符,因为这相当于每个时间步中仅将上一个时间步的预测结果进行传入,而无法关注到以往预测到的全部单词,这与预期严重不符。(经过研究后发现虽然it的shape始终为(bs,1),但在实际预测时还是对所有已知的单词进行了考虑,下面对方法进行介绍。)
经过观察,发现在decoder layer中的时间代码为,显然其可以分为三个部分,即Masked self-attention,Cross-attention,FFN,其中masked中作为query,key,value的传入参数均为input,其形状为(bs,1,dim)
将参数传入上述的self.self_att后,对于接收到的query、key、value使用下面的代码,再次进行一次调用,在这次调用的self.attention方法内才会真正进行softmax(QK)V的注意力运算。
值得注意的是,在此处调用self.attention时传入的queries、keys、value形状分别为(以step=3为例):(bs,1,dim)、(bs,3,dim)、(bs,3,dim),即在key和value处神奇的对已有的全部单词做了考虑,而在上一步中分明将同一个形状为(bs,1,dim)的input同时作为self.self_att传入参数的query、key、value,这中间发生了什么?
可以看到在接收到keys和value后,先令其与self.running_keys和self.running_values拼接,然后再赋值给keys和values。由于每个时间步中均会进行这样的操作,因此当处于step=3时,self.running_keys和self.running_value中将会存储有第1个单词和第2个单词,新接收到的keys和values中储存有第三个单词,将二者拼接即为全部的已知单词。
到这里,已经介绍完了transformer在进行预测时的操作,下面补充一些细节
3.补充细节
3.1需要在对每个batch预测前将self.running_keys和self.running_values置为(bs,0,dim)的形式,如何做到的?
在初始化函数init中使用上述代码进行定义,其中self.register_state的定义如下
可以看到,会将第一个参数name保存到列表self._state_names中,将第二个参数default保存到字典self._state_defaults中。然后调用nn.model中的self.register_buffer方法,创建一个名为self.name的变量,并使用default为其赋值,同时令其梯度为False,即不会被优化。
通过上述代码,可以得到形状为(0,dim)的张量,还缺少batch_size维度,同时上面代码仅说明了如何进行初始化,而为介绍如何在对每个batch预测前进行置0.
在本文贴出的第一段代码中,曾说暂时忽略第一行。此处再次将该段代码贴出,以针对第一行进行说明。
该行代码对应的操作如下
通过递归的方式令所有的child调用enable_statefullness,然后调用self._init_states()方法,代码如下
其效果是利用self._state_defaults中储存的数据为self._buffers赋值,上面我们介绍register_state时提到过利用(0,dim)的张量为self._state_defaults列表赋值,后续没有对其修改的操纵,因此其中储存的数据始终为(0,dim)的张量。
self._buffers中储存的是self.running_keys和self.running_value的值,在这个init函数中利用self.state_defaults为self._buffers赋值,实际上就重置self.running_keys和self.running_value中储存的数据。
因此,在对时间步建立for循环的外面进行这个操作,就可以实现对每个batch预测前将self.running_keys和self.running_values置为(bs,0,dim)的形式。
3.2 self._is_stateful与self.can_be_stateful
之前提到的代码中,有写分支if需要对这两个值做判别,下面介绍这两个值的赋值情况。
self.can_be_stateful是初始化对象时的传入参数,仅在self-attention用作masked SA时将其设置为True,其余时候均为False。
self._is_stateful在上述的enable_statefullness方法中被置为True,位置通常为对每个batch的时间步建立循环之前。
(完)