大前端

前端学习之家-大前端

[Transformer]Fastformer:Additive Attention Can Be All You Need

Fastformer:加性注意力就是你所需要的


  • Abstract
  • Section I Introduction
  • Section II Related Work
    • Part 1 Transformer and Self-Attention
    • Part 2 Efficient Transformer
  • Section III Fastformer
    • Part 1 Architecture
    • Part 2 Complexity Analysis
  • Section IV Experiments
    • Part 1 Datasets
    • Part 2 Effective Comparisons
    • Part 3 Efficiency Comparison
    • Part 4 Influence of Interaction Function
    • Part 5 Influence of Parameter Sharing
  • Section V Conclusion and Future Work

Abstract

Transformer是一个强大的文本理解模型,但是其计算复杂度是输入序列长度的二次项,计算效率十分低下。尽管已经有诸多Transformer加速的方法,但是要么在长序列上表现不好,要么加速效果不尽人意。
本文提出的Fastformer则是基于加性注意力的模型,首先本文使用加性注意力机制来建模全局上下文关系,然后根据每个token与全局其他token的交互进一步对token进行转换,这样计算复杂度只要线性即可完成注意力的计算。
本文在5个数据集上验证了Fastformer的有效性,比目前现有的Transformer要更加高效,同时在长序列文本上实现更好的建模功能。

Section I Introduction

Trabsformer及其变体在诸多领域大获成功,比如BERT、GPT是目前NLP领域的基准模型;并且在视觉相关任务中Transformer也大放异彩。Transformer的核心是SA的计算,可以对整个输入序列的依赖关系进行建模。但是由于SA需要将当前输入与所有输入之间进行点积,复杂度是输入序列长度的平方项,因此在处理长序列文本时效率十分低下。

目前在Transformer加速方面已经进行了系列尝试,如BigBird采取稀疏的思路,通过计算局部注意力和部分位置的全局注意力以及随机注意力来近似密集注意力的计算;但是这种稀疏注意力一般无法完全模仿全局注意力。

Linformer则是用低秩近似计算注意力矩阵,但是这会减弱Transformer的建模能力。此外,这些加速方法在 处理长序列时效率依旧很低。


本文提出的Fastformer是一种基于加性注意力的Transformer模型,可以有效地在线性时间复杂度内计算注意力。

首先本文使用加性注意力机制将输入序列总结为一个全局的查询向量(query vector);随后再对global vector和所有key之间计算元素级别的点积来学习全局的上下文敏感的key matrix;并通过加性注意力进一步将其总结为全局的key-vector。


随后再对key和value进行元素级别的聚合,在进行线性变换得到全局的上下文敏感的注意力值。
最后将最开始的query与attention value相加得到最终的输出。


本文的这种方法在5类基准数据集上进行了实验,包括情感分类、主题预测、新闻推荐和文本总结。实验结果充分证明了Fastformer的高效性。

本文工作总结如下:


(1)本文提出一种高效的Fastformer模型,据我们所知是目前效率最高的Transformer模型;



(2)本文提出通过元素相乘建模全局上下文与token之间的交互,可以更高效的建模上下文信息;



(3)实验证明Fastformer比许多当前的Transformer模型都更为高效。

Section II Related Work

Part 1 Transformer and Self-Attention

Transformer基于多头注意力建立起来的,MHA可以有效捕获所有输入之间的交互关系,多头注意力是一系列单头注意力级联的结果。
从注意力的计算公式可以看出,计算复杂度是输入序列长度的平方项,这也成为了限制Transformer处理长序列的瓶颈所在。

Part 2 Efficient Transformer

近年来已经有诸多工作探究高效Transformer的设计。比如Longformer采用的是滑爽注意力和部分全局注意力混合的方式建模全局注意力;Big Bird计算的则是局部注意力和部分位置的全局注意力。但是在长序列时也需要引入更多的token才能避免性能下降。



另一思路则是基于哈希来加速SA的计算。比如Reformer使用的是multi-rounr哈希策略将相似的特征表述放在同一个桶中来计算SA,这样将SA的计算复杂度降到O(NLOGN),但是Reformer只适合处理特定长度的序列。
此外还有其他方法,如基于近似的方法。



Linformer使用低秩近似的方法;Linear Transformer则使用基于核的SA计算公式,借助矩阵之间的关联来近似计算注意力,但是这仍然是上下文无关的SA计算方式,并不是建模文本关系的最佳方法;此外当序列很长时依旧会有很大的计算成本。




Fastformer与上述方法都不同,是通过加性注意力来建模全局上下文,使用element-wisr product来建模输入之间的关联,这样可以显著降低计算成本,同时有效捕获上下文信息。

Section III Fastformer

Fig 1展示了Fastformer的结构框图。首先通过加性注意力将query序列浓缩为global query vector,计算global query vector和所有key的点积;继续使用加性注意力将key浓缩为global key vector,计算出attention value;随后将global key与attention value计算得到注意力值,经过现行颜射之后获得全局的上下文敏感的注意力值;最后与query 相加得到最终输出。

这样将SA的计算降到线性,同时也可以有效捕获全局上下文信息。

Part 1 Architecture

Fastformer最初依旧会将输入映射为Q,K,V;本文将注意力头设置为3个。


接下来,根据输入序列的上下文信息来捕获Q,K,V之间的依赖关系是核心问题,在原始Transformer中使用的是点积,从而使得复杂度为N的平方。



在这里插入图片描述

一种可能的优化方法是在计算交互关系之前使用总结过的Q,K,V,而加性注意力可以有效的总结重要信息。
因此本文首先利用加性注意力将query矩阵总结为一个global query vector:
在这里插入图片描述
在这里插入图片描述

然后使用golbal query与所有key之间的元素计算点积建模交互作用,并将它们组合成一个上下文感知的key矩阵。
接下来问题就变成了如何计算global query与key matrix之间的交互作用,到底是通过相加还是级联?但是这两种方法都无法区分到底是query还是key施加的影响。


而元素集相乘是建模向量之间关联的有效操作,因此本文计算的是query 与key matrix之间的element-wise production;最后依旧使用加性注意力获得总结后的global key。
在这里插入图片描述
在这里插入图片描述

最后建模global key与attention value之间的交互作用,依旧是使用element-wise product。受原始Transformer启发,本文对每一个key-value交互使用一个线性转换来学习其隐层表示。
得到的矩阵与query matrix相加获得Fastformer的输出。其中每个注意力头的输出是沿着隐藏层轴这一维度级联的。

可以看到在Fastformer中每一个key和value都会与global query或global key进行交互,从而学习上下文关系。通过叠加多个Fastformer层可以完整的建模上下文信息。
借助value和query的权重共享可以进一步减少内存需求;本文还进一步在不同Transformer层共享参数,从而进一步减少参数大小,降低过拟合风险。

Part 2 Complexity Analysis

下面进行复杂度分析。加性注意力的时间复杂度和内存需求均为O(N * d),附加参数总数为2hd。
element-wise product的计算复杂度也是O(N*d),总的复杂度为O(N * d^2);与原始的O(N ^ 2 * d)相比有很大提升。
结合权重共享,每层总的参数量为3hd ^ 2+2hd;而原始Transformer至少需要4hd^2.
以上分析充分证明了Fastformer的高效。

Section IV Experiments

Part 1 Datasets

本文在5个数据集上验证了Fastformer的有效性


训练卡:V100 32GB


每个实验重复5次取5次结果的平均值

在这里插入图片描述

Part 2 Effective Comparisons

与Fastformer对比的有:


Transformer,Longformer(sparse attention),BigBird(sparse random attention),Linformer(low-dimension k,v),Linear Transformer,Poolingformer(shift window+pooling SA)

在这里插入图片描述

Table 4展示了在不同任务上的对比结果,可以看到改进过的Transfomrer均比原Transformer性能优异,这是因为原始Transformer受限于SA的计算复杂度,限制了可以处理的最大序列长度,在截断输入本文时丢失了许多有用的上下文。

而Fastformer与其他变体相比,在建模长程和短程上下文关系时具有更好的性能,这是因为Fastformer建模了全局上下文与token之间的关系,有助于更好的理解上下文信息。

Table 5展示的是新闻推荐的对比结果,参与对比的网络有:NRMS,FIM,PLM-NR等。可以看到Fastfdormer性能最好,也优于NRMS模型,结合PLM-NR可以进一步提升性能,这也表明Fastformer不仅适合于文本建模,也适合于用户理解类的任务。


Table 6则展示的是文本总结的结果,可以看到在CNN.DM数据集上,一些变体效果不如原始的Transformer,这是因为这些基于稀疏的方法不能完全建模上下文信息,而基于近似的方法也不能在近似过程中有效的考虑上下文信息。因此当序列长度相同时,性能还不如原始的Transformer。而Fastformer在大多数指标上均是最优的,也进一步证明了Fastformer在自然语言生成方面的优势。
在这里插入图片描述

Part 3 Efficiency Comparison

Table 7展示的是不同Transformer变体的计算复杂度,其中N为序列长度,K为平均token的数目,d为隐藏层的维度。但是Bigbird和Longformer的复杂度均取决于每个token参考的token的平均数量,Linear Transformer的复杂度则与隐藏层维度有关,Poolingformer取决于window的尺寸。而Fastformer只取决于序列长度和h,是复杂度最小的。



在这里插入图片描述

在这里插入图片描述
此外,本文还测试了在不同序列长度,不同batchsize上的表现,序列长度跨度从128到65535.Fig 2是实验结果。可以看到Fastformer即使在输入序列很长时间推理时间也较短,而且本文还发现虽然Poolingformer声称其具有线性复杂度,但在实际应用中并不高效。因为它使用了较大的window size来计算池化权重,这样会显著增加计算成本。

Part 4 Influence of Interaction Function

本文还探究了使用不同方法建模交互作用,主要是addition,concatenation,以及element-wise product,对比结果参见Fig 3.
在这里插入图片描述

可以看到concatenation并不是最优选择,因为仅仅将两个向量级联,不能很好的表示他们之间的交互关系;add稍好一些,但也只能建模两个向量之间的线性交互,并不足以学习精确的上下文表征。而element-wise producy可以有效的建模非线性关系,有助于模型捕获长序列中复杂的上下文信息。

Part 5 Influence of Parameter Sharing

本文还探究了不同权重共享方式对精度的影响,包括:共享Q,V 矩阵参数,注意力头之间的参数共享,以及参数的跨层共享。


在这里插入图片描述

Fig 4是实验结果,可以看到只是用Q,V共享会比原始Fastformer有一定心梗提升,同时可以有效减少参数量;但是head共享会带来精度的下降,因为使用不同的注意力头是希望可以捕获不同的上下文模式,因此共享他们的参数对上下文建模没有什么好处;本文还发现跨层共享可以进一步提升模型性能。因此本文最终选择用QV参数共享和跨层参数共享。

Section V Conclusion and Future Work

本文提出的Fastformer是一种基于加性注意力的Transforemr模型,可以有效的在线性时间复杂度内捕获长程关系,借助加性注意力在SA的计算过程汇总使用的是总结后的全局query,key等信息。在5类数据集上的实验结果表明Fastformer可以在文本建模中保持较高精度的同时计算更加高效。

未来本文预计使用预训练后的结果进一步提升Fastformer对较长文本的建模能力,此外,本文还将探索Fastformer在其他应用中的表现,如电商推荐、广告预测等。

发表评论:

Copyright Your WebSite.Some Rights Reserved.