RWKV: 大语言模型结构的另一种选择

文章首发于 网站机器翻译学堂
转载事宜请后台询问哦
作者 | 阮俊豪
单位 | 东北大学自然语言处理实验室

前言
Transformer[1]在诸多的NLP任务上产生了非常惊艳的效果,甚至逐渐辐射到CV领域(如Vision Transfomrer[2]),获得了学术界和工业界一致的认可。因此也被作为当下大语言模型结构的不二之选。无论是以BERT[3]为代表的,常用于分类任务的Encoder-only模型;亦或是解决生成类任务为主的Decoder-only模型GPT[4];或兼而有之的Encoder-Decoder架构的T5[5]模型,他们都采用了transformer的部分或完整架构。
尽管如此,Transformer作为大语言模型的标准架构选择,也存在一些不能忽视的缺陷,例如内存和时间复杂度都与输入序列的长度成平方,这很大程度的影响了大语言模型部署在端侧或资源受限设备的可能性。

随着Scailing Law[6]的提出,大公司开始将研究重心转移到如何建设能容纳更大参数规模的基础支撑设备,更深层的训练技巧保证梯度和训练稳定性,借助更大数据规模的有效微调方式。在力大砖飞,物理限制还没有看到尽头的当下,较少有人继续研究Transformer之外,还有没有更适合大语言模型的结构。一方面因为Transformer结构设计上具有很高效的GPU并行程度,另一方面是因为Transformer在各种NLP任务取得良好成果的情况下尝试其他模型的大参数训练有不小的试错成本。
一位独立研究员彭博[7],在2021年8月份,就提出了他的原始RWKV[8]构想,并在完善到RKWV-V2版本之后,在reddit和discord上引发业内人员广泛关注。现今已经演化到V4版本,并充分展现了RNN模型的缩放潜力。本篇博客将介绍RWKV的原理、演变流程和现在取得的成效。
RWKV模型原型
An Attention Free Transformer[9]
标准AFT
RWKV模型中的time-mix的设计受An Attention Free Transformer工作影响较深,所以开始之前,我们先介绍一下Attention Free Transformer(AFT)。
这篇文章之所以名字中带有Free,是因为它完全去除了标准transfomrer中的点乘注意力,同时也没有采取其他linear transformer工作中常见的点乘注意力近似方法。
我们把标准的多头自注意力中每个头的输出结果表示为:
其中
是第 i 个头的线性变换。 是非线性函数,默认情况下是softmax函数。
AFT中,Q、K、V来源仍然是输入的线性变换结构。
我们把记作模型的输出,那么模型 t 时刻的输出在AFT中表示为:
其中指的是element-wise乘法,就是按位相乘
。
是sigmoid函数,和softmax的区别是,sigmoid是一个个矩阵元素计算的,softmax是一行行计算的。
就是一个可学习的位置偏置。
我们可以粗略地将这个过程提炼为 ,和原始的点乘注意力 相比,元素乘的时间开销显著降低,且不需要显式地计算和保存softmax出来的权重矩阵,还保留了全局的K、V之间的交互。需要特别注意的是,和原始多头自注意力不同。AFT的softmax是以列(时间维度)做归一化的,类似于池化。所以下文RWKV-V1也把这个设计称为Time-mix。原始多头自注意力是在隐藏层维度上做的。
对于每一个输入 的位置
,AFT形成了以
为权重的
的加权平均数。
为了进一步阐述AFT和多头注意力的联系,我们按照标准多头自注意力的形式,来描写时刻第
个注意力头的输出
在原始AFT的工作中,仍有两点关于的内容需要厘清。首先,AFT中
,
应该指代的是从
中取出第
行和第
列的元素。其次,公式中
与
的加和基本上都采用广播方式。以
为例子,
,
,因此他们的加和可能实际上需要由
做转置,然后在新的列维度上扩充至
方可运算。除此之外,为了简单描述,上述过程并没有讨论mask的置入方案,在实际操作中,可以通过控制求和范围来等价实现
AFT-local
在很多场景中,局部性都是一个很重要的归纳偏差。也有一些工作基于这个性质开展,比如OpenAI现在使用在GPT系列上的sparse transformer。AFT-local发现,训练后的transfomrer的注意力模式更倾向于集中局部。为了更详细的表达这个理念,我们用一副图可视化Vision Transformer(ViT)的注意力矩阵。这是一个12层6注意力头数的ViT在256张图片上的平均注意力模式,其中纵向维度是层数(2层一统计),横向维度是注意力头数。星光亮度越大的地方代表了更高的注意力权重。

可以看到,这个图里展示了相当强的局部模式,这个观测引出了AFT的变体——AFT-local的设计。
和上文的区别是,将位置偏置的值做了一个区域限制:
这里s就是一个局部的窗口大小。
AFT-simple
AFT-local的一个极端模式就是令s=0,也就是完全从AFT中抽走了位置偏置,从而得到
AFT还有第三种变体,叫做AFT-conv,其更适用于图像任务。因篇幅所限,感兴趣的读者可以查阅原文了解。
GLU Variants Improve Transformer
channel-mix的部分则受该节工作启发。尤其是其中的GeGLU。
标准Transformer中的feed-forward network采用的是如下的结构:
T5取走了最外层的偏置,修改为
也有其他工作尝试使用GELU或者Swish替代ReLU.
其中
在上面的例子里,两层可学习的线性变换是按顺序堆叠在一起的,一层的输出作为第二层的输入。后续Gated Linear Units(GLU)提出了另一种形式,这种形式如果省略掉激活函数也被称作bilinear layer。
同理,GLU上也存在一个非线性激活函数,我们可以使用GELU等函数去替换。RWKV-1所涉及的GeGLU就来自于GELU+GLU:
最后仍然采用省略偏置的结构替换FFN:
本文最后还测试了许多不同GLU变体作为FFN的效果,感兴趣的读者可以参看原文。
RWLV-V1
这个版本的工作还比较类似linear transformer的工作,而不是纯粹的RNN网络。在彭博的设计中,RWKV模型由交替的Time-mix和Channel-mix层组成。
两者均拥有类似的R\W\KV结构设计,故此得名。其中R\K\V由输入线性变换生成,W是一个可学习的参数矩阵。
笔者认为,和AFT工作中的标记方式不同,矩阵的下标不仅代表取元素,也同时代表维度表示。例如可以认为是
取出了第
行的元素。
可以看出,Time-mix层与AFT-simple基本相同,其区别包括,修改归一化
相较于原始的在完整时间序列上的归一化,Time-mix现在采用的是一种只回看历史序列的局部归一化。
除此之外,还将W分解,支持多头
channel-mix和上文提到的也基本相同,因为K和V是由层输入x线性变换得到的,因此相当于只是增加了一个额外的self-gating R,彭博测试后发现确实提高了拟合性能。
最后再加上了彭博2020年8月的一个想法 time-shift就组成了RWKV-V1版本。time-shift主要涉及到relative position embedding,以及把输入从改成
表面上看,这是强制要求网络结合 x[t] 和 x[t-1],是个 inductive bias,就像强制要求网络使用 2-gram。
但后来我做了更多实验,发现它对于深层模型也有效。而且,对于 SA 和 FFN 都有效。
这乍一看有点奇怪,把 x[t] 的一半通道,用 x[t-1] 的通道代替,是不是有点过分了?经实验,确实可以用这么强的混合。
后来再想想,我们训练 GPT 时,网络的 hidden representation 实际在做两件不同的事情:
1. 预测下一个字。有时这很简单(有时下一个字很明显)。
2. 收集前文的 context 信息,传递给后文。这永远是困难的任务。
这两件事有明显的区别。第2件事更难。
我的理论是:在加入 time-shift 后,没有 time-shift 的通道主要承担 (1),被 time-shift 的通道主要承担 (2)。所以,这就实现了更明确的信息分离。
而且 time-shift 可让信息快速传递,就像一个小卷积。在多层后可以看很长的距离。
上述这段引述自彭博对于time-shift的描述,我解释一下最后一句话,“在多层后可以看很长的距离”,假设层数有6层,在t时刻的第6层的输入,依赖于第5层t-1时刻和t时刻的输入,而第五层t-1时刻的输入,又依赖于第四层t-2时刻的输入。。。依次类推到第一层t-5时刻的输入。所以赋予了回看历史信息的能力。这个观点一般是在local attention的模式里会提到的,被用于解释这一段做法也挺契合。如果认为收集前文context的信息更难,time-mixing就在注意力回看的基础上,直接让输入特征也一并回看了。
对于普通的 QKV 自注意力,我观察权重,看各层对于两种通道的使用程度,发现 Q 偏向于使用【没有 time-shift 的通道】,而 V 偏向于使用【被 time-shift 的通道】,符合这个理论。
这个观测我觉得很有意思,如何观察Q和V对于不同层的偏向使用情况,我倾向于作者可能对线性变换矩阵做了一个类似热力图观察,发现不同通道部分存在不同的热力,比如前半部分通道更热,
后半部分更热。
该版本在 simplebooks-92 的 character-level 性能对比了灰色基线(普通 MHA 多头注意力 + Rotary encoding + GeGLU FFN),和加入各种魔改的MHA的黑色线版本,均有竞争力。

RWKV-V2
这个版本的改动有些大,我们先从自注意力层开始。为了方便,我把作者的伪代码图转成了模型结构图。和RWKV-V1相比,区别在于修改了time-mix的softmax处理,显式地加入了token-shift机制,从永远的当前词和历史词各取一半channel合并改成了可训练参数T调节,我们也可以把这个叫做shift门。



其中W是被预先计算好的值初始化的(具体怎么预先计算的彭博未提及,但是受启发于alibi编码),不同的channel使用不同的W,而且更小的W被用在更低的层上。这是因为底部层的平均衰减更快,对应短程信息;顶部层的平均衰减更慢,对应长程信息。
作者还给了一个提醒,要clamp k【是阈值裁剪,对上界限制了60】和对d加上来预防overflows。
接下来是FFN层,相较于RWKV-V1,这算是完全新增的层了。

每个时间步只依赖于 ,而
或
只依赖于
和
,摆脱了像传统RNN对历史状态的依赖,所以可以很方便的展开训练时并行。a、b分别代表kv和k的滑动平均数。c和d则是a、b加上self-attntion,同时也是记忆机制。T,K,V,R,W,X,P全是可训练的参数矩阵。
模型在实现时还有个headQK机制,会快速地看一遍前文,可以让模型从前文复制或避免某些字。
在v2这个阶段,LSTM的发明和奠基者Sepp Hochreiter也看到了这个工作,并给予了一定认可。

为了更好地阐述V2的内容,我们仍然着重展示RWKV-V2如何从降至
。同时讲述RWKV更向RNN靠近的理由。笔者在RWKV-V1的时候,曾说过V1还是很类似于linearize attention的工作,每一个时间步都依赖于历史时刻所有的输入。而rnn类型的模型则是完全依赖于上一个时刻的输入和当前时刻的输入,以及固定大小的某些状态(在RWKV-V2里是a和b)。以下是彭博给出的RWKV-V2的自注意力层简化表达
RWKV-V3
这个阶段是一个非常快的过度阶段,相较于前两版近一年的跨度,V3只持续了两个月左右,且文字资料较少。在RWKV-LM的项目中提到,R K V 的来源变化了。在V2版本中,是先生成一个mix的。
我们用一幅图简单展示一下这个变换。

另一个变换是,使用preLN替换postLN(更稳定且更快收敛)。不过好像preLN在V2已经采用了。
自注意力层的实现是:对x做time-shift,然后分别根据不同的矩阵映射成,再映射成R\K\V
RWKV-V4
虽然这版星在22年底就出现了,但是我们可以认为最后定档应该是出自这篇论文RWKV: Reinventing RNNs for the Transformer Era.框架的整体结构如下图所示:

time-mixing block的实现是:
开始仍然是和v3一样的映射操作,然后有一点小小的变化。就起到了transformer自注意力的
,但是时间复杂度是
,因为只需要遍历序列,每次操作都是简单的加法,时间复杂度是
。这里面的
也是一个位置偏置,就好像我们第一次在AFT中讨论的那样,不过现在
。要求w中的元素非负是为了让每个通道的权重相当于往历史时间上衰减。且越远的时间步,比如i=1,受到的遗忘力度就越大。
而channel-mixing block的实现是:
这里采用了squared ReLU。
但是还有一点非常重要,因为解决了时间复杂度的问题,transformer之所以如此强大,是因为它有很好的并行性质,能充分利用GPU。而就time-mixing block来看,是存在时间依赖的。也就是说训练时给定一句话,不能利用类似teacher-forcing的方法训练,因为它还依赖于前面时刻的状态。
对于这个问题,彭博采用了Simple recurrent units for highly parallelizable recurrence (SRU)里的思路,简单地说,只依赖于原始输入x的部分,都可以预先计算好,例如。而
这种element-wise product则可以按照batch和dimmension两个维度并行。达到一个相对完备的并行状态。
上面的写法看起来并不太像RNN,因为RNN并不会全局的回看所有序列,而是通常依赖于当前输入和来自历史的状态,不过,我们在V2的时候就知道,这些mixing都可以重写成RNN-block的形式。transformer解码的时候通常会利用一个kv缓存来获得一定的速度提升,但随着序列长度增加也会带来很多空间占用,RNN形式却不会碰到这种问题。比如,上面的我们可以重写成下面这种形式,它只依赖于输入
和状态
在训练的时候,RWKV采取并行模式,而在推断的时候,RWKV可以采取RNN模式。
下图是语言建模任务下,RWKV-LM的运行过程。

传统的RNN通过使用非饱和激活函数、门控机制、梯度裁剪、添加约束等多种方法来解决梯度稳定性问题,但RWKV通过类似于transformer和RNN的融合,本质上地具有了更稳定的梯度。RWKV包含全时间依赖的softmax操作有助于数值稳定和防止梯度消失。层归一化也在这方面起到了很重要的作用。论文附录中,作者给出了RWKV在梯度稳定性上的数学证明(详见附录F)。同时,这样的设计也能够以超过任何现有RNN的能力的方式实现深层堆叠,模型能够捕获跨不同抽象级别的更复杂的模式。

上面的图,展现的是位置衰减偏置在channel维度的衰减大小。可以看出,在后续层中,模型的上下文信息被几乎完整的保存和传播。而在低层中,衰减曲线很快下滑,提示底层比较关注局部信息。
下面的图则展示了信息的检索和传播路径,采用的是Locating and editing factual associations in GPT 中提到的方法。
运行模型一次,记录计算过程中课程的所有状态和激活情况。采用噪声破坏被试的输入嵌入(例子里用的是“埃菲尔铁塔”)。还原计算过程中某一层在某一个令牌处的状态和激活情况,记录模型输出正确答案( '巴黎')的对数概率。
与transformer不同,RWKV依赖于信息在时间维度上的递归传播。在这种情况下,"埃菲尔铁塔位于巴黎"这一事实在第4层被检索到。然后将其传递给后续的层。在第20层中,信息主要通过时间进行传播,直到到达需要的地方。最后,将其传递到最后一层进行答案的输出。
一些局限性
RWKV解决长程依赖问题,也就是传递历史上下文信息的机制,在RWKV-V4中有三种——递归、时间衰减和token shift。之所以在此处仍然重复已经提到的内容,是因为在这三种机制下,RWKV的长度外推效果仍然欠佳。
RWKV是靠记忆来完成任务的,也就是只会开卷考试不会闭卷考试。所以RWKV对prompt比较敏感,要把任务描述的token放到最前面,带着问题阅读材料效果才比较好。
当模型宽度继续加大时,线性RNN的时间复杂度可能更依赖于隐藏层维度d,使得标准的attention机制也在序列上接近线性了。
结语
彭博是一个工程实力非常强悍的独立研究员,早期就在知乎分享了非常多对于模型改进的思路和实现方案,将talk is cheap,show me the code展现的淋漓尽致。他能从大量的论文阅读中真正提取出别人的精华,并进行尝试,在积累一定量后进行融合,也由此诞生了RWKV。RWKV是一个非常有意思的模型,通过SRU的思路,解决了RNN训练并行效率的问题,通过AFT、time-shift、相对位置编码等多种思路融合,加强了RWKV的长程依赖且缓解了训练不稳定的问题。通过geglu、squred ReLU、FFN with R gate、自定义初始化等多种工作进一步强化了transformer中的FFN层。
在OpenAI实际上并不Open的当下,RWKV从21年创立之初完全开源,既受到开源社区很多帮助,也反哺了开源社区很多的成果。现在ChatRWKV已经在同尺寸上展现出了相当惊人的表现。对于LM基座感兴趣的读者,可以参看这个链接,而想在线体验的读者,也可以从这个链接直接体验RWKV-4-World-7B模型。
RWKV-V5的构想和改进计划也已在近日公布,相信在可预见的未来,大语言模型的结构选择除了transformer,也将会有完全由国人设计的RWKV的一席之地。

hi,这里是小牛翻译~
想要看到更多我们的文章,可以关注下
机器翻译学堂(公号或网站)
笔芯~

往期精彩文章


