微信扫一扫,关注公众号

  • 科技行者

  • 算力行者

见证连接与计算的「力量」

首页 对角线批处理技术:突破循环记忆Transformer模型在长上下文处理中的并行瓶颈

对角线批处理技术:突破循环记忆Transformer模型在长上下文处理中的并行瓶颈

2025-06-09 16:57
分享至:
----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.-
2025-06-09 16:57 科技行者

近日,来自俄罗斯AIRI、Skoltech、MIPT和MBZUAI等研究机构的研究团队发表了一项重要研究成果,由Danil Sivtsov、Ivan Rodkin、Gleb Kuzmin、Yuri Kuratov和Ivan Oseledets共同完成。这篇题为《对角线批处理技术:突破循环记忆Transformer模型在长上下文处理中的并行瓶颈》(Diagonal Batching Unlocks Parallelism in Recurrent Memory Transformers for Long Contexts)的论文于2025年6月5日在arXiv上发布(arXiv:2506.05229v1)。

想象一下,如果你曾经尝试使用大型语言模型处理一本小说或长篇报告这样的长文本,你可能会遇到两个主要问题:一是处理速度慢得令人沮丧,二是模型会因为内存不足而崩溃。这正是目前Transformer模型面临的主要挑战 - 它们在处理长文本时需要消耗平方级的计算资源和线性增长的内存空间。

循环记忆Transformer(RMT)模型提供了一个聪明的解决方案:它们将长文本切分成较小的片段,并通过特殊的"记忆"机制将信息从一个片段传递到下一个片段。就像我们阅读一本长篇小说时会记住前面章节的重要情节,然后带着这些记忆继续阅读后面的章节一样。这种方法将计算复杂度从平方级降低到线性级,内存使用从线性增长降低到恒定大小。

然而,这种循环记忆方法也带来了新的问题:由于每个片段的处理都依赖于前一个片段的结果,模型被迫按顺序处理所有内容,无法充分利用现代GPU的并行计算能力。这就像一群厨师被迫排成一队,一个接一个地完成烹饪任务,而不是同时在厨房的不同区域工作一样,极大地降低了效率。

研究团队提出的"对角线批处理"技术巧妙地解决了这个问题。他们发现,虽然片段之间存在依赖关系,但通过重新安排计算顺序,可以在不破坏这些依赖关系的情况下实现部分并行计算。简单来说,这种方法就像是让厨师们采用错开的工作时间,在保证菜品准备顺序的同时,最大化厨房的使用效率。

应用于LLaMA-1B ARMT模型上,对角线批处理技术在处理131,072个标记(相当于一本中等长度的书籍)的长文本时,比标准的LLaMA-1B模型快3.3倍,比顺序执行的RMT实现快1.8倍。更令人印象深刻的是,在内存使用方面,ARMT模型比原始模型节省了惊人的167.1倍空间。

这项研究的重要性不仅在于提高了处理速度,还在于它不需要对现有模型进行任何重新训练 - 这是一种纯运行时的计算重排技术,可以直接应用于现有的RMT模型。对于需要处理长文本的实际应用,如文档分析、书籍理解或长对话历史跟踪,这项技术提供了一个实用且高效的解决方案。

一、对角线批处理:让模型像协调舞蹈一样高效工作

传统的Transformer模型在处理长文本时面临着严峻的挑战。它们的计算成本随着文本长度的平方增长,内存需求也随着文本长度线性增长。这就像是一个阅读者需要在阅读每个新句子时回顾所有之前读过的内容 - 当文章变得非常长时,这种方法变得极其低效且消耗大量资源。

工程师们已经提出了许多优化方案来解决这些问题。比如FlashAttention和xFormers库通过减少内存访问开销来提高吞吐量;多查询注意力(MQA)、分组查询注意力(GQA)和多头潜在注意力(MLA)等技术通过共享和优化KV缓存来降低GPU内存使用;而Ring Attention和Microsoft DeepSpeed的Ulysses则尝试将序列数据分布在多个设备上,以突破单个GPU的内存限制。

除了这些工程优化,研究人员还探索了替代标准Transformer的架构。线性循环模型如S4、RWKV、RetNet和Mamba用替代注意力机制的读写操作取代了标准的softmax注意力。这些模型像Transformer一样支持高效的并行训练,同时在推理过程中又像RNN一样只需要恒定的内存。然而,这些方法通常会降低模型的记忆容量和读写操作的准确性。

记忆增强模型,特别是具有段级循环的记忆增强Transformer,提供了另一种解决方案。这些模型将历史信息压缩到固定大小的记忆状态中,并在段之间传递这些状态。在循环记忆Transformer(RMT)中,特殊的记忆标记在段之间携带状态信息,每个Transformer块就像一个循环单元。这种方法将推理复杂度降低到线性时间和恒定内存,能够支持任意长的上下文。然而,RMT的循环特性使其难以充分并行化:所有后续层都具有循环依赖关系,所有段都必须按顺序处理。

并行循环记忆Transformer(PRMT)是一类更广泛的架构,其中每一层都维护自己的记忆状态。PRMT在层内本地化循环,消除了所有层间记忆流动。联想循环记忆Transformer(ARMT)属于这个家族,并展示了卓越的可扩展性,它能够在长达5000万个标记的序列上保持高质量,远超RMT和Mamba的能力。RWKV、Mamba和其他线性循环架构也可以被视为PRMT家族的成员,因为它们都采用了层级记忆设计。

然而,实际上,这些方法仅在单个段内利用并行性,而这种并行性受到RAM和计算限制。因此,在处理极长序列时,这些方法往往退化为按顺序处理段,甚至退化到标记级的循环处理,没有解决真正的段间并行问题。

本研究提出的对角线批处理技术是一种调度方案,可以在不改变PRMT精确循环特性的情况下,解锁其推理过程中的段间并行性。通过将层和段计算的二维网格重新组织为独立的"对角线",该方法能够在每次GPU内核启动时并发执行多达N_Layers个操作。对角线批处理完全封装了跨段的Transformer块计算,从而消除了之前RMT实现中存在的层级和段级同步障碍。

研究团队在ARMT框架中实现了对角线批处理,并在NVIDIA A100/H100 GPU上对LLaMA-1B、3B和8B模型进行了评估,序列长度最长达到131,072个标记。实验结果表明,对于1B模型,与标准全注意力推理相比,对角线批处理实现了3.3倍的加速,与顺序ARMT基线相比,实现了1.8倍的改进。这些结果证明,对角线批处理是一种在极长上下文上进行精确、线性时间推理的实用解决方案。

二、解开循环记忆Transformer的并行潜力

为了理解对角线批处理的工作原理,我们需要先了解循环记忆Transformer的基本架构。循环记忆Transformer(RMT)通过引入段级循环来扩展标准Transformer架构。具体来说,对应于段s的隐藏表示依赖于从前一个段s-1传播的循环状态M(即所谓的记忆)。

在原始RMT公式中,记忆状态被实现为一系列嵌入向量。记忆更新机制可以正式表示为:[_, _, Ms] = Transformer([Ms–1, Hs–1, Ms–1]),其中Ms表示与段s相关的记忆状态,Hs–1表示来自段s-1的输入嵌入,方括号表示输入序列的连接。

联想循环记忆Transformer(ARMT)引入了一种并行记忆机制,设计用于支持分层记忆结构。与原始RMT不同,ARMT在不同层之间维护不同的记忆状态。这种设计通过允许每一层存储和更新自己的记忆,实现了更具表现力的记忆表示。

ARMT中的记忆更新规则通过一系列复杂的数学公式实现,本质上是实现了带有delta规则的准线性注意力,用于段级循环。这种机制使ARMT能够在保持计算效率的同时处理极长的序列。

对角线批处理方法主要适用于层级循环架构,其中每个段(时间步)的输出仅依赖于同一层中前一段(时间步)的输入和输出。这类模型被广泛称为并行循环记忆Transformer(PRMT)。

在ARMT中,每一层l都有自己的记忆状态,由联想矩阵Al组成。记忆状态通过特殊的联想块更新,该块接收前一段t-1的Transformer层输出Hl t-1作为输入。这种每层记忆允许我们优化哪些段可以并行计算以及在哪些层进行计算的调度。

对角线批处理的核心直觉来自对依赖图的分析。在朴素方法中,我们必须执行大量前向操作(n_segments × n_layers),每个操作处理形状为(segment_size, hidden_size)的输入。

由于并行记忆的使用,每个(segment, layer)对只依赖于前面的对:(segment, layer-1)和(segment-1, layer)。鉴于这种依赖关系,所有segment + layer = i的对可以在第i次迭代中并行计算。每次迭代可以被可视化为前向传递计算图中的一条对角线。

如果执行不受计算能力限制,这种对角线执行方法可以带来显著的加速。需要注意的是,这种属性仅适用于并行记忆模型。在递归记忆模型中,每个(segment, layer)依赖于所有先前的(segment-k, layer-n)对,使得对角线批处理不适用。

通过这种方式,对角线批处理技术将n_layers × n_segments个顺序操作减少到n_layers + n_segments个分组计算,大大提高了处理效率。

三、技术实现与性能突破

对角线批处理技术的实现需要对模型架构进行一些修改。研究团队使用ARMT框架作为基础,将所有层替换为单个分组层。具体来说,他们对基本模型架构进行了以下调整:

首先,将线性层替换为GroupedMatmul操作。权重和偏置是通过堆叠原始层的权重和偏置构建的。这就像是将多位厨师的菜谱合并到一个大食谱中,让他们可以同时按照这个统一的食谱工作,而不是每个人依次使用各自的食谱。

其次,层归一化权重也通过在所有层之间堆叠参数来替换。此外,前向传递经过调整以确保正确的广播行为。这相当于确保所有厨师使用统一的测量标准和工具,以便协调工作。

最后,所有其他操作保持不变,但它们操作的方式就像处理大得多的批次大小一样,从而促进并行执行。就像厨师们虽然各自工作,但都在同一个大厨房里协同操作,共享设备和空间。

对于分组矩阵乘法,研究团队使用了CUTLASS库中的GroupedGEMM函数,并进行了一个小优化:输出张量被预先分配为单个大张量,然后在不增加额外开销的情况下被分割成单独的子矩阵。

在实验部分,研究团队评估了对角线批处理方法在单个请求推理和批处理策略方面的性能。实验使用了Llama-3系列模型进行,包括160M、1B、3B和8B参数大小的模型变体。

首先,研究团队分析了网络内瓶颈操作的效率提升情况。对于线性层,他们发现分组GEMM的FLOPS(每秒浮点运算次数)随着组大小的增加而增长,类似于相应批次大小的GEMM。这为他们的方法与底层模型批次大小扩展的相似性提供了基础。他们将组大小设置为模型中的层数,使分组GEMM操作达到a100和h100 GPU的峰值GEMM flops,确保高利用率。

对于注意力层,研究团队没有修改任何内容,而是让注意力层执行批处理操作,批次大小等于层数。这将其性能提升到实现FLOPS峰值。

这些单个操作的性能提升直接转化为整体模型加速。在所有模型大小和批次配置中,他们的实现始终比默认的ARMT实现实现了显著的加速。对于较小的段大小,增益尤为明显。这是因为,对于较大的矩阵乘法,硬件利用率已接近峰值FLOPS,留给组扩展的空间较少。

这些结果的一个关键含义是,研究人员可以优先考虑基于质量的段大小选择,而不必过于受性能限制。对角线批处理将性能与段大小解耦,为架构决策提供了更大的灵活性。

在对角线批处理与小批量处理的比较中,研究团队在相同的硬件和模型配置下测量了每段的计算时间。结果表明,对角线批处理在几乎所有测试场景中都实现了与微批处理相匹配的每段计算扩展。

为了提供可实现性能的上限,他们还报告了理想均匀负载情况,即所有段计算都使用具有最大可实现FLOPS的完整分组层进行计算。可以看到,这种均匀负载设置要好得多,基本上匹配或超过最大批次大小的性能。它们之间的差距是当前实现的效率低下之处。

值得注意的是,对角线批处理为较大的模型(从1B参数开始)提供了显著的性能改进,特别是当段大小适中时。对于这些配置,对角线批处理匹配大批次大小的性能。

这些发现表明,对角线批处理有效地捕获了大批次推理的利用率优势——通过并行化调度而非增加内存分配。

四、误差积累与实际应用

研究团队还对对角线批处理在推理阶段的误差积累进行了实证研究。他们的实验表明,对于短于32,768个标记的所有序列,总体误差小于2%,这与生产中使用的其他高效层实现相当。例如,他们观察到FlashAttention2与其他注意力实现相比,在相同的随机输入序列上产生1-2%的相对logits误差。

误差的详细值表明,随着段数的增加,误差会逐渐积累,但不会超过2%的阈值。然而,误差积累对下游任务的影响可以忽略不计。为了证明这一点,研究团队在BABILong基准测试上评估了训练好的ARMT模型,结果表明,原始实现和使用对角线批处理的实现在BABILong基准测试上达到了相同的结果。

对于64k长度的标记序列,对角线批处理可以将相对速度提高3.2倍,显著提升了处理效率。这意味着在实际应用中,对角线批处理可以在保持模型性能的同时,大大缩短处理长文本的时间。

研究团队还实现了对角线批处理的反向传播,以支持训练过程。通过对齐训练和推理代码,消除了可能导致logits级浮点漂移的差异。这进一步确保了对角线批处理技术在全流程应用中的稳定性和可靠性。

在实际应用方面,对角线批处理技术对不同大小的模型均显示出显著的性能提升。对于Llama-160M模型,在处理131,072个标记的序列时,对角线批处理比基本ARMT快3.9倍;对于Llama-1B模型,快2.7倍;对于Llama-3B模型,快1.3倍;对于Llama-8B模型,快1.14倍。

这些结果表明,对角线批处理技术在各种模型规模下都能有效提升性能,特别是对于中小型模型,提升更为显著。对于需要处理长文本但计算资源有限的应用场景,这项技术提供了一个实用的解决方案。

总的来说,对角线批处理是一种通过重新安排计算顺序来提高循环记忆Transformer推理效率的创新技术。它不需要对模型进行重新训练,可以无缝集成到现有系统中,并在保持高精度的同时显著提高处理长文本的速度和效率。

五、结论与未来展望

Transformer模型在处理长上下文推理时仍然面临着计算复杂度平方增长和内存需求线性增长的挑战。虽然Mamba、RWKV和循环记忆Transformer(RMT)等线性复杂度架构试图解决这些问题,但RMT特别有吸引力,因为它只需对现有架构进行最小的修改,确保与现有模型和算法的兼容性。

本文证明,RMT及其层记忆变体(PRMT)的主要瓶颈不是算法复杂度,而是调度问题:循环依赖迫使细粒度同步,导致现代加速器利用不足。研究团队提出的对角线批处理方法通过将层-段计算网格重新组织为有利于并发的对角线,解决了这一问题,从而使每个内核能够处理多达N_Layers个操作,而不改变精确的循环特性。

实验结果表明,配备对角线批处理的Llama-1B ARMT在处理131,072个标记的上下文任务时,比普通Llama-1B快3.3倍,比顺序RMT实现快1.8倍,同时保持了结果的高精确度(相对误差仅1%)。

考虑到这些优势,对角线批处理将PRMT理论上吸引人的计算扩展转变为在极长上下文上进行精确线性时间推理的实用解决方案。通过消除主要的性能障碍,它使记忆增强的循环Transformer成为下一代LLM应用的有竞争力且可扩展的基础,这些应用需要高效的长范围输入处理。

然而,尽管具有这些优势,对角线批处理也存在一些实际限制。首先,它不直接兼容具有层内循环的循环记忆Transformer(RMT)。不过,更有前途的方法是专注于并行RMT,之前的工作已经证明这种方法更为有效。其次,当前的实现假设层配置均匀。当模型采用异构层或不同的隐藏大小时,应用该技术需要更复杂的分组逻辑和手动工程。最后,可实现的加速随层数增加而增加,因此较浅的模型或层数很少的模型只会看到适度的性能提升。

未来的研究方向可能包括:进一步优化对异构架构的支持,探索在其他类型的循环神经网络上应用类似技术的可能性,以及结合其他优化技术(如量化和稀疏化)进一步提高性能。随着大型语言模型在各种应用领域的普及,高效处理长上下文的能力将变得越来越重要,对角线批处理技术为解决这一挑战提供了一个有前途的方向。

分享至
0赞

好文章,需要你的鼓励

推荐文章
----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.-