2024-05-08 挑战Transformer:全新架构Mamba详解

背景

屹立不倒的 Transformer 迎来了一个强劲竞争者。

自 2017 年被提出以来,Transformer 已经成为 AI 大模型的主流架构,但随着模型规模的扩展和需要处理的序列不断变长,Transformer 的局限性也逐渐凸显。一个很明显的缺陷是:Transformer 模型中自注意力机制的计算量会随着上下文长度的增加呈平方级增长,比如上下文增加 32 倍时,计算量可能会增长 1000 倍,计算效率非常低。

为了克服这些缺陷,研究者们开发出了很多注意力机制的高效变体,但这往往以牺牲其有效性特为代价。到目前为止,这些变体都还没有被证明能在不同领域发挥有效作用。

而就在最近,一名为 Mamba 的架构似乎打破了这一局面。

与类似规模的 Transformer 相比,Mamba 具有 5 倍的吞吐量,而且 Mamba-3B 的效果与两倍于其规模的 Transformer 相当。性能高、效果好,Mamba 成为新的研究热点。

图1 Mamba 在推理过程中的吞吐量对比

本文将详细的解读 Mamba 架构,由于 Mamba 是基于 SSM->HiPPO->S4->Mamba 演化过来的,而 HiPPO、S4、Mamba 的一作者都是卡内基梅隆大学机器学习系助理教授 Albert Gu。因此,本文将从标准 SSM 开始,逐步介绍 HiPPO、S4、Mamba。

图2总结了SSM、HiPPO、S4、Mamba的主要区别,以及各个模型的主要内容。本文内容也将按图中内容展开。

图2-2:HiPPO、S4、Mamba

一、现有架构问题

序列建模的核心问题是:同时解决有效高效。有效是指能够选择性记忆历史信息,解决长距离依赖(Long-Range Dependencies,LRDs)问题;高效是指计算高效。

尽管传统的模型如循环神经网络(RNNs)、卷积神经网络(CNNs)和 Transformers 在处理长距离依赖方面有专门的变体,但它们在处理超过 10000 步的极长序列时仍然面临挑战。

1.1 Transformer 问题

Transformer 的一个主要优点是,无论它接收到多长的输入,它都使用序列中的所有 token 信息(无论序列有多长)来对输入数据进行处理。

图1-1:Transformer会查看过去所有 token

但是为了获得全局信息,注意力机制在长序列上非常耗费显存。注意力创建一个矩阵,将每个 token 与之前的每个 token 进行比较。矩阵中的权重由 token 对之间的相关性决定。

图1-2:Transformer 会计算每个 token 之间的 Attention

在训练过程中,Attention 计算可以并行化,所以可以极大地加快训练速度。但是在推理过程中,当生成下一个 token 时,我们需要重新计算整个序列的注意力。

图1-3:生成新 token 时需要重新计算整个序列的注意力

长度为 L 的序列生成 token 大约需要 L² 的计算量,如果序列长度增加,计算量会平方级增长。因此,需要重新计算整个序列是 Transformer 体系结构的主要瓶颈。

图1-4:Transformer 训练快、推理慢

1.2 RNN 的问题

图1-5:循环神经网络 RNN

在生成输出时,RNN 只需要考虑之前的隐藏状态和当前的输入。这样不会重新计算以前的隐藏状态,这正Transformer 不具备的。

这种结构可以让 RNN 进行快速推理,并且理论上可以无限扩展上下文长度,因为每次推理只取一个隐藏状态和当前输入,内存占用非常稳定。

RNN 的每个隐藏状态都是之前所有隐藏状态的聚合。但是这里会有一个问题,在生成 token “Liang” 时,最后一个隐藏状态不再包含关于 token “Hello” 的信息。这会导致随着时间的推移,RNN 会忘记更久的信息,因为它只考虑前一个状态。

图1-6:只考虑前一个 hidden state

并且 RNN 的这种顺序性产生了另一个问题。训练不能并行进行,因为它需要按顺序完成每一步。

图1-7:RNN 训练不能并行

RNN的统一定义为:

可以看到,当前梯度依赖上个 token 的梯度。

与 Transformer 相比,RNN 的问题完全相反!它的推理速度非常快,但不能并行化导致训练很慢。

图1-8:RNN 和 Transformer对比

人们一直在寻找一种既能像 Transformer 那样并行化训练,能够记住先前的信息,又能在推理时时间是随序列长度线性增长的模型,Mamba 就是这样应运而生的。解下来我们从 SSM 开始,逐步介绍 Mamaba。

二、状态空间模型 SSM

2.1 什么是 SSM

2.2 SSM 架构

下图是 SSM 的架构,主要包含两个部分:状态更新方程和输出方程。

图2-1:SSM结构

SSM 可以简化为以下结构:

图2-2:简化的SSM结构

下面我们看一下更详细的结构,首先是状态更新,如下所示:

图2-3:状态更新详细结构

然后是输出方程,详细机构如下所示:

图2-4:输出方程详细结构

2.3 SSM 例子:弹簧振子

下面举一个描述弹簧振子系统的 SSM 例子。

图2-5:弹簧振子

考虑一个质量为m的物体,它连接在一个劲度系数为k的弹簧上,并且受到阻尼系数为c的阻尼力作用。当物体从平衡位置偏离时,它会在弹簧力的作用下进行振动。我们可以用状态空间模型来描述这个系统的动态。

状态变量可以选择为物体的位移s(t)和速度v(t)。输入u(t)在这个例子中可以为零,因为我们没有外部力作用在物体上。输出y(t)可以是我们感兴趣的位移s(t)。

状态向量定义为:

输入向量为:

输出位移s(t)。弹簧振子的状态空间方程可以表示为:

在了解 SSM 基本概念之后,接下来我们介绍基于 SSM 的 HiPPO 架构。

三、HiPPO(High-order Polynomial Projection Operators)

HiPPO 是 Albert Gu 于2020年在论文 HiPPO: Recurrent Memory with Optimal Polynomial Projections 中提出的新架构。HiPPO 主要为了解决如何在有限的存储空间中有效地解决序列建模的长距离依赖问题。

HiPPO 通过函数逼近产生状态矩阵 A 的最优解,有效的解决了长距离依赖问题。

问题背景:在处理序列数据时,一个核心问题是如何在增量方式下表示累积的历史信息。这涉及到如何在有限的存储空间中有效地更新和维护历史数据的表示。

HiPPO框架:作者介绍了一个名为 HiPPO(High-order Polynomial Projection Operators)的通用框架,它通过将连续信号和离散时间序列投影到多项式基上,实现了在线数据压缩。

重要性度量:HiPPO 框架考虑了一个度量,用于指定过去每个时间步的重要性。这个度量帮助HiPPO产生在线函数逼近问题的最优解。

理论贡献:HiPPO 框架不仅提供了对现有记忆单元的简短推导,还推广了循环神经网络(如GRUs)中普遍存在的门控机制。

新的记忆更新机制:作者提出了一个新的记忆更新机制(HiPPO-LegS),它能够随时间扩展以记住所有历史信息,避免了对时间尺度的先验假设。

理论优势:HiPPO-LegS 具有时间尺度鲁棒性、快速更新和有界梯度的理论优势。

实验结果:在基准测试中,HiPPO-LegS 在打乱的 MNIST 数据集上达到了98.3%的新最佳准确率。在一个新的轨迹分类任务中,HiPPO-LegS 在处理分布外时间尺度和缺失数据方面,比其他 RNN 和神经 ODE(一阶常微分方程)基线模型的性能提高了25-40%的准确率。

下面介绍 HiPPO 实现的具体细节。

3.1 HiPPO 架构:高阶多项式投影

3.1.1 HiPPO问题设置

问题定义

由于函数空间的庞大,无法完美记住整个历史,因此需要将其进行压缩,HiPPO 提出了将历史投影到有界维数的子空间的一半方法。

函数逼近与度量

多项式基展开

任何 N 维的函数子空间 G 都是逼近的合适候选。参数 N 对应于逼近的阶数,或者说压缩的大小;投影的历史可以通过G的任何基的N个系数来表示。

论文中使用多项式作为自然基,因此G是小于N阶的多项式的集合。

在线逼近

挑战

3.1.2 HiPPO 通用架构

通过连续动态系统计算投影

这部分是 HiPPO 的关键步骤,它涉及到将输入函数u(t)在时间 t 投影到一个多项式空间上,以便在线更新记忆表示。

在线更新:通过这个连续动态系统,HiPPO 框架能够在线更新记忆表示,即随着新数据的到来,系统能实时地调整系数x(t)。

在线函数逼近

图3-1:HiPPO框架

3.1.3 高阶投影:度量方法以及 HiPPO 动态系统

作者定义了两种度量方法,分别是 LegT 和 LagT。LegT 度量为最近的历史信息分配均匀的权重,表示如下:

LagT 度量使用指数衰减的方式来衡量历史信息的重要性,表示如下:

对于 LegT 和 LagT,系数x(t)可以使用 ODE(一阶常微分方程)来表示:

备注:公式(9)是 HiPPO 框架的关键部分,具体推导可以参看论文中的附录 D。

对于 LegT 度量,矩阵 A 和 矩阵 B 可以表示如下:

对于 LagT 度量,可以表示如下:

3.1.4 HiPPO 框架中的连续时间动态转换为离散时间递归关系

由于我们处理的输入往往是离散的,因此我们需要将公式(9)的 ODE 离散化。ODE 离散化是一种常用的数据技术,它将连续时间的常微分方程转换为离散时间的差分方程。这通常涉及到选择一个合适的时间步长(或步长Δt),并使用数值方法(如欧拉方法、双线性)来近似连续微分。

图3-2:连续信号离散化

使用双线性离散化,如下所示:

结合公式(9)和公式(11),我们可以得到离散化的状态更新公式,表示如下:

离散化之后的 SSM 结构可以表示如下:

图3-3:离散化 SSM

图3-4:每个时间步的计算

这种表示看起来是不是有点熟悉?其实他的处理方法和RNN一样。

图3-5:离散化后和RNN类似

3.2 HiPPO-LegS

HiPPO-LegS 是作者基于新的度量提出的全新架构,具有时间鲁棒性、有界梯度、有界近似误差、长时间记忆等效果。

具体推导在论文的附录 D.3 部分。

更好的学习长期依赖

HiPPO-LegS 是专门为记忆而设计的,它通过其独特的结构和更新机制来避免梯度消失问题。LegS 通过使用Legendre 多项式作为基函数,并结合时间尺度不变的度量,来保持梯度的稳定性。

这个性质使得 HiPPO-LegS 能够有效地缓解 RNN 中的梯度消失问题。即使在长序列中,梯度也不会迅速衰减到0,这有助于网络在训练中更好地学习长期依赖。

近似有界误差

3.3 实验

模型结构如下:

下面是pMINIST 数据集上的结果,可以看到 LegS 的效果要好于 LagT 和 LegT,同时 HiPPO 的效果好于之前的其它模型。

备注:pMNIST(permuted MNIST)是一个经过修改的MNIST数据集,它用于测试和评估机器学习模型在处理序列数据和学习长期依赖关系方面的能力。在 pMNIST 中,原始 MNIST 图像的像素被重新排列。这意味着图像的像素不再是按照自然顺序(从左到右,从上到下)呈现,而是按照一个固定的、随机的排列顺序。这种排列方式使得模型必须学习像素之间的长期依赖关系,而不能简单地依赖于局部空间结构。

四、S4 (Structured State Space Model)

S4 是 HiPPO 的后续工作,论文名称为:Efficiently Modeling Long Sequences with Structured State Spaces。

S4 的主要工作是将 HiPPO 中的矩阵 A(称为 HiPPO 矩阵)转换为正规矩阵(正规矩阵可以分解为对角矩阵)和低秩矩阵的和,以此提高计算效率。

S4 通过这种分解,将计算复杂度降低到了0(N+L),其中 N 是 HiPPO 矩阵的维度,L 是序列长度。

在处理长度为 16000 的序列的语音分类任务中,S4 模型将专门设计的语音卷积神经网络(Speech CNNs)的测试错误率降低了一半,达到了1.7%。相比之下,所有的循环神经网络(RNN)和 Transformer 基线模型都无法学习,错误率均在70%以上。

下面我们就来介绍一下这篇工作。

4.1 HiPPO 解决了长期依赖

作者讨论了如何处理长距离依赖(Long-Range Dependencies,LRDs)的问题,LRDs 是序列建模中的一个关键挑战,因为它们涉及到在序列中跨越大量时间步的依赖关系。

作者指出,基本的 SSM 在实际应用中表现不佳,特别是在处理 LRDs 时。这是因为线性一阶常微分方程(ODEs)的解通常是指数函数,这可能导致梯度在序列长度上呈指数级增长,从而引发梯度消失或爆炸的问题。

为了解决这个问题,作者利用了 HiPPO 理论。HiPPO 理论指定了一类特殊的矩阵 A,当这些矩阵被纳入 SSM 的方程中时,可以使状态 x(t) 能够记住输入 u(t) 的历史信息。这些特殊矩阵被称为 HiPPO 矩阵,它们具有特定的数学形式,可以有效地捕捉长期依赖关系。

HiPPO 矩阵的一个关键特性是它们允许 SSM 在数学和实证上捕捉 LRDs。例如,通过将随机矩阵 A 替换为 HiPPO 矩阵,可以在序列 MNIST 基准测试上显著提高 SSM 的性能。

HiPPO 矩阵表示如下:

4.2 在线推理:使用递归形式

S4 在推理时,使用公式(12)的递归形式,每次只需要和上一个状态进行计算,具有和 RNN 相似的推理效率。

4.3 训练 S4:卷积表示

由于离散时间 SSM 的递归性质,它在硬件上进行训练时存在效率问题。因此,作者将离散时间 SSM 的递归方程转换为离散卷积的形式。通过展开递归方程,可以得到一个卷积核,这个卷积核可以用来在序列数据上应用卷积操作。这种转换允许 SSM 利用快速傅里叶变换(FFT)等高效的卷积计算方法,从而在训练过程中提高计算效率。

上面式子可以转化为卷积的形式:

作者在这一节中还讨论了如何计算 SSMn卷积核,这是他们技术贡献的关键部分。通过这种卷积表示,SSM 可以被有效地训练,同时保持其在处理长距离依赖(LRDs)方面的能力。这种表示形式为 SSM 在各种序列建模任务中的应用提供了灵活性,包括图像处理、语音识别和时间序列分析等。

图4-1:SSM 卷积核形式

下面是一个具体的例子,如何使用卷积核生成输出。

图4-2:使用卷积核生成输出

卷积的一个主要好处是它可以并行训练。但是由于核大小是固定,它们的推理不如 RNN 快速并且对序列长度有限制。

图4-3:递归 SSM 和 卷积 SSM 的对比

这里可以使用一个简单的技巧,即根据任务选择表示。在训练过程中使用可以并行化的卷积表示,在推理过程中,我们使用高效的循环表示。

图4-4:递归推理、卷积训练

4.4 为什么对角化可以减少 SSM 计算复杂度

为了进一步提升计算效率,作者讨论了对角化在计算离散时间状态空间模型(SSM)中的应用,以及为什么直接应用对角化方法在实践中并不可行。

对角化是一种线性代数技术,它可以将一个矩阵转换为对角形式,从而简化矩阵的乘法和其他运算。在 SSM 的上下文中,对角化可以显著减少计算复杂度,因为对角矩阵的幂运算(如在递归方程中出现的)可以通过简单的元素指数运算来完成。

下面我们解释下,为什么对角化可以减少 SSM 计算复杂度。

首先,我们引入论文中的定理 3.1

(Lemma 3.1):共轭是 SSM 中的等价关系,即:

第一个 SSM:

第二个 SSM:

通过V将第二 SSM 乘以V后,变成如下形式:

4.5 直接对角化 HiPPO 矩阵导致数值溢出

下面证明下这个结论。

首先我们可以找到矩阵A的一个相似矩阵,表示如下:

其中:

那么可以找到一个可逆矩阵:

即然无法直接对A矩阵进行对角化,那么是否可以将其转化为低秩矩阵或者其它可以对角化的矩阵?解下来我们介绍下如何将其转换为正规矩阵+低秩矩阵。

4.6 S4 参数化:正规矩阵+低秩矩阵

虽然矩阵A不能直接对角化,但是可以表示为正规矩阵+低秩矩阵。

Theorem 1:HiPPO 矩阵A可以表示为正规矩阵+低秩矩阵的形式,即:

下面简单证明下这个定理。

已知 HiPPO 矩阵A可以表示为:

这样我们就将矩阵A转换为了正规矩阵+低秩矩阵的形式。下面我们看一下转换之后的递归计算和卷积计算的复杂度。

4.7 S4 的计算复杂度

经过正规矩阵+低秩矩阵分解后,我们再来考虑 S4 的计算复杂度有什么变化。我们同时考虑推理时递归计算的复杂度以及训练时卷积计算的复杂度。

先给出结论:

这里的 Cauchy 矩阵-向量乘法复杂度表示如下:

如果Cauchy 矩阵-向量乘法按照精确计算,那么 S4 的卷积复杂度为

解下来我们详细介绍 S4 计算复杂度的分析过程,首先介绍递归计算复杂度。

递归计算复杂度

由于:

现在,我们可以将公式(30)重新表示为下面的形式:

公式(12)的 SSM 可以重新表示为:

卷积计算复杂度

这一块就不再具体介绍了,感兴趣的可以直接去看原论文,在论文的附录 C.3 有详细的分析过程。

最后作者对比了 S4 和原始卷积、递归、Attention 之间的计算复杂度,可以看到 S4 是最低的,如下图所示:

图4-5:计算复杂度对比

图中 L 表示序列长度,B 表示 batch size,H 表示隐藏维度。

4.8 实验结果

推理效率

序列长度为 1024 时,S4 的推理速度是 Transformer 的 1.58 倍;序列长度为 4096 时,是 Transformer 推理速度的 5.19 倍,由于不需要 KV cache,因此内存占用非常小。

S4 作为生成模型的效果

将 S4 应用在生成模型中,实验结果如下图所示。

图4-7:S4 作为生成模型的效果

HiPPO 影响

这部分作者进行了一系列的消融实验(Ablations),以评估 HiPPO 矩阵在状态空间模型(SSM)中的重要性。这些实验旨在探究 HiPPO 矩阵在 S4 模型中的作用,以及它对于模型性能的影响。

HiPPO 矩阵是 S4 模型中用于处理长距离依赖(LRDs)的关键组件。在这一节中,作者通过以下几个方面的实验来验证 HiPPO 矩阵的重要性:

  1. HiPPO 初始化:作者首先研究了不同初始化方法对 SSM 性能的影响,包括随机高斯初始化、HiPPO 初始化以及随机对角高斯矩阵初始化。实验结果表明,HiPPO 初始化在提高模型性能方面起到了关键作用。
  2. HiPPO 矩阵是否可训练:作者还探讨了 HiPPO 矩阵固定以及可训练的效果。他们发现,固定 HiPPO 和可训练的差异不大。
  3. NPLR SSMs:作者进一步研究了在没有 HiPPO 矩阵的情况下,随机 NPLR(Normal Plus Low-Rank,正规+低秩矩阵)的表现。结果表明,即使在 NPLR 形式下,这些随机矩阵的性能仍然不佳,这验证了 HiPPO 矩阵在 S4 模型中的核心作用。

通过这些消融实验,作者强调了 HiPPO 矩阵在 S4 模型中的重要性。这些实验结果不仅证实了 HiPPO 矩阵在处理长距离依赖方面的有效性,而且也表明了它在提升模型整体性能方面的关键作用。这些发现对于理解 S4 模型的设计和优化至关重要。

图4-8:HiPPO 矩阵初始化效果远远高于其它矩阵初始化

虽然 S4 在保证了计算效率的同时,优化了长距离依赖问题。但是由于矩阵ABC是固定不变的,和输入 token 无关,这就导致了 S4 在一些合成任务上效果不佳,比如选择性复制任务。

而为了解决这些问题,作者提出了 Mamba 架构,通过选择性机制改进 S4,有效解决了这类问题。下面我们就来介绍下最近很火的 Mamba 结构。

五、Mamba

我们终于介绍完了理解 Mamba 所需要的基础知识。状态空间模型可用于建模文本序列,但仍有一系列我们想要避免的缺点。

在本节中,我们将介绍 Mamba 的两大主要贡献:

  1. 一种选择性扫描算法,该算法允许模型过滤(不)相关信息;
  2. 一种硬件感知算法,该算法允许通过并行扫描、内核融合和重新计算来高效存储(中间)结果。

它们共同创建了选择性 SSM 或 S6 模型,这些模型可以像自注意力一样用于创建 Mamba 块。

在探讨这两大主要贡献之前,让我们首先探讨一下为什么它们是必要的。

状态空间模型,甚至是S4(结构化状态空间模型),在某些对语言建模和生成至关重要的任务上表现不佳,即关注或忽略特定输入的能力。

我们可以通过两个合成任务来说明这一点,即选择性复制和归纳头。

在选择性复制任务中,SSM 的目标是复制输入的部分内容并按顺序输出它们:

图5-1:选择性复制任务

然而,由于(循环/卷积)SSM 是线性时间不变的,因此在这项任务中表现不佳。正如我们之前看到的,对于 SSM生成的每个 token,矩阵 A、B 和 C 都是相同的。

因此,由于固定的 A、B 和 C 矩阵,SSM 无法执行内容感知推理,因为它对每个 token 都一视同仁。这是一个问题,因为我们希望 SSM 能对输入(提示)进行推理。

SSM 表现不佳的第二项任务是归纳头,其目标是重现输入中发现的模式:

图5-2:重现输入中发现的模式

在上面的例子中,我们本质上是在执行一次提示,我们试图“教”模型在每个“Q:”之后提供一个“A:”的回应。然而,由于 SSM 是时间不变的,它无法选择从历史中回忆哪些之前的 token。

让我们通过关注矩阵 B 来说明这一点。无论输入 u 是什么,矩阵 B 都保持不变,因此与 u 无关:

图5-3:矩阵 B 与输入 u 无关

同理,无论输入是什么,A 和 C 也不变,这就是我们上面说的静态。

图5-4:A 和 C 矩阵也和输入 u 无关

相比之下,这些任务对于 Transformer 来说相对容易,因为它们会根据输入序列动态地改变自己的注意力。它们可以选择性地“查看”或“关注”序列的不同部分。

SSM 在这些任务上的糟糕表现说明了时间不变 SSM 的潜在问题,即矩阵 A、B 和 C 的静态性质导致内容感知方面的问题。

5.1 通过选择机制改进 SSM

为了解决上面的问题,作者提出了一种新的选择性 SSM(Selective State Space Models,简称 S6 或 Mamba)。这种模型通过让 SSM 的矩阵 A、B、C 依赖于输入数据,从而实现了选择性。这意味着模型可以根据当前的输入动态地调整其状态,选择性地传播或忽略信息。

Mamba 集成了 S4 和 Transformer 的精华,一个更加高效(S4),一个更加强大(Transformer)。

图5-5:Mamba 集成了 S4 和 Transformer 各自的优点

正如上面所提到的,它是通过有选择地将数据压缩到状态中来实现的。当你有一个输入句子时,通常会有一些信息,比如停用词,没有太多意义。

为了有选择地压缩信息,我们需要让参数依赖于输入。为此,我们首先来探讨一下 SSM 在训练过程中输入和输出的维度。

图5-6:SSM 输入(u)输出(y)维度

备注:在前面的 HiPPO 和 S4 中,我们假设的输入信号u(t)是 1 维的,而实际应用中大多数都是多维的,后面我们默认是多维输入(默认维度为 D)。而且需要强调的是 S4 用的是 Single-input-single-output (SISO),即对应于每一个输入的维度,都有一套独立的 SSM 参数 (传统的 RNN 是 MIMO,multiple-input-multiple-output, 很容易混淆)。

在 S4 中,矩阵 A、B 和 C 与输入无关,因为它们的维度 N 和 D 是静态的,不会改变。

图5-7:S4 中的矩阵A、B、C

相反,Mamba 通过将输入序列的长度和批次大小结合起来,使矩阵 B 和 C,甚至步长 ∆ 都依赖于输入:

图5-8:Mamba 中的矩阵B、C和输入有关(L是序列长度)

这意味着对于每个输入 token,我们现在有不同的 B 和 C 矩阵。

备注:这里矩阵 A 保持不变,因为我们希望状态本身保持静态,但影响它的方式 (通过 B 和 C) 是动态的。

它们一起选择性地决定在隐藏状态中保留什么和忽略什么,因为它们现在依赖于输入。

在 SSM 中,通过调整 ∆,模型可以控制对当前输入的关注度,从而实现类似于 RNN 门控的效果。例如,当 ∆ 较大时,模型倾向于关注当前输入并忽略之前的信息;而当∆较小时,模型则倾向于保留更多的历史信息:

图5-9:步长 ∆ 效果相当于门控

下面我们看一下选择性 SSM 的完整过程,如下所示:

算法 2 展示了作者所使用的主要选择机制。这一套的思路由来已久,Transformers 里面的 QKV、LSTM里面的、Gating 都是类似的思想。

S4 和 选择性 SSM 的核心区别在于,它们将几个关键参数(∆, B, C)设定为输入的函数,并且伴随着整个 tensor 形状的相关变化。特别是,这些参数现在具有一个长度维度 L,这意味着模型已经从时间不变(time-invariant)转变为时间变化(time-varying)。

其中 :

5.2 选择性 SSM 和门控之间的关系

时间步∆

时间步∆和 RNN 的门控有很强的关联,依赖输入的∆跟 RNN 的遗忘门的功能类似。

可以看到这就是一个带门控的 RNN。

矩阵 B 和 C

在 SSM 中,修改 B 和 C 以使其具有选择性,允许模型更精细地控制是否让输入进入状态 h 或状态进入输出 y,所以 B 和 C 类似于 RNN 中的输入门和输出门。

矩阵 A

5.3 Mamba 高效实现

因为现在的参数ABC都是输入相关了,所以不再是线性时间不变系统,也就失去了卷积的性质,不能用 FFT来进行高效训练了。

Mamba 作者采用了一种称为硬件感知的算法,实际上就是用三种经典技术来解决这个问题:内核融合(kernel fusion)、并行扫描(parallel scan)和重计算(recomputation)

而 Mambda 作者的方法是:

图5-10:Mamba 的 scan 和其它方法对比

Mamba 的实现比其它方法实现快很多倍,scan 在输入长度 2k 的时候就开始比 FlashAttention 快了,之后越长越快。同时 scan 也比 Convolution 快。

5.4 Mamba 架构

下图是 Mamba 的模型结构:

图5-11:Mamba 模型结构

之前的 SSM 模型要 work,都会加上 output gating,之后再过个线性层 channel mixing,如上图的最左边所示。这两个部分跟 Gated MLP(上图中间)右边的支路和最上面的 channel mixing 是一样的。所以 SSM 层如果跟Gated MLP 合并的话,难免会感觉有点冗余,所以作者干脆把两个合二为一,把 token mixing 层和 channel mixing。

图 5-11 的 Mamba 可以作为一个块来实现,就像我们可以在解码器块中表示自注意力一样。

图5-12:Mamba 块

与解码器一样,我们可以堆叠多个 Mamba 块,并使用它们的输出作为下一个 Mamba 块的输入:

图5-13:多个Mamba块组合使用

它首先进行线性投影以扩展输入 embedding。然后,在应用选择性 SSM 之前进行卷积。选择性 SSM 具有以下属性:

  • 通过离散化创建递归 SSM;
  • 对矩阵 A 进行 HiPPO 初始化,以捕获远程依赖关系;
  • 选择性扫描算法,有选择地压缩信息;
  • 硬件感知算法,加速计算;

下面是一个端到端(输入到输出)的例子:

图5-14:Mamba 架构端到端输出例子

下面我们看一下 Mamba 和 Transformer 以及 RNN 的对比:

图5-15:Mamaba和Transformer以及RNN对比

5.5 实验

之前提到∆的作用类似遗忘门,而遗忘门毫无疑问是 LSTM 里面最重要的门,下面这个消融实验结果就论证了∆data dependent 影响最大。

图5-16:对不同参数data dependent的敏感性

最后再看一下模型效果:

图5-17:Mamba和Transformer对比

总结起来就是,效果最好,速度最快!

总结

融合 SSM 和 LSTM,将 LSTM 选择性的思想融入 SSM 中,全方位的实现优化,使得 Mamaba 即具备像 Transformer 高效训练的特点,又具备 S4 中支持长文本的优点,同时具备 LSTM 一样选择性记忆的特点。

实验证明了 Mamba 的优秀,但是还需要更长的时间检验,目前还没有 10B 以上的 Mamba 模型,就让时间来检验。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注


往期评论