微信扫一扫,关注公众号

  • 科技行者

  • 算力行者

见证连接与计算的「力量」

首页 StreamBP:让大语言模型训练长序列变得更轻松——香港中文大学(深圳)团队突破性研究

StreamBP:让大语言模型训练长序列变得更轻松——香港中文大学(深圳)团队突破性研究

2025-06-09 15:31
分享至:
----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.-
2025-06-09 15:31 科技行者

如果你关注人工智能和大语言模型的发展,一定听说过ChatGPT、Claude或Gemma等能够处理长篇文本的AI助手。但你可能不知道,训练这些模型处理长文本时面临一个巨大的挑战:内存不够用了!这就像是试图在一个小餐盘上准备一顿丰盛的十道菜晚宴——空间严重不足。

近日,香港中文大学(深圳)的研究团队Qijun Luo、Mengqi Li、Xiao Li以及上海交通大学的Lei Zhao在2025年6月发表了一篇题为《StreamBP: Memory-Efficient Exact Backpropagation for Long Sequence Training of LLMs》的研究论文(arXiv:2506.03077v1),提出了一种突破性的解决方案。这项研究已在GitHub上开源(https://github.com/Ledzy/StreamBP),任何人都可以将其整合到自己的模型训练流程中。

在深入了解这项研究之前,我们先来理解问题的本质。当大语言模型(LLM)需要学习处理长文本时,比如解决复杂的数学问题或编程任务,它们通常需要处理包含详细推理过程的超长序列。这些序列可能长达数万甚至十万个标记(tokens)。在训练过程中,尤其是反向传播(Backpropagation,简称BP)阶段,需要存储大量的中间激活值,这会占用巨量的GPU内存。就像你需要记住烹饪过程中每一步的状态才能知道哪一步出了错——当步骤太多时,你的"记忆空间"就会不堪重负。

现有的解决方法如梯度检查点(gradient checkpointing)技术虽然能节省一些内存,但仍然需要存储大量的中间值,导致训练长序列时GPU内存不足。StreamBP提出了一个优雅的解决方案:将反向传播过程沿着序列维度进行线性分解,显著减少了存储激活值和logits(模型输出前的原始预测值)所需的内存。

一、StreamBP的核心思想与工作原理

想象一下,你正在教一个孩子学习一首长诗。传统方法是让孩子一次性背诵整首诗,然后检查每个字的发音。如果发现最后一个字发音错了,你需要重新听孩子背诵整首诗来纠正。这就像传统的反向传播——需要保存整个序列的信息。

而StreamBP的方法则是将诗歌分成小段,每次只关注一小段的发音,然后逐段累积纠正。这样,你的"工作记忆"只需要存储当前这一小段的信息,大大减轻了记忆负担。

从技术角度看,StreamBP基于链式法则的线性分解。在模型训练中,我们需要计算损失函数对模型参数的梯度。假设有一个转换函数fW(Zin) = Zout,其中W是与转换相关的权重。通过链式法则,我们可以计算损失L对权重W的梯度:?L/?vec(W) = (?vec(Zout)/?vec(W))^T · (?L/?vec(Zout))。

传统方法需要一次性计算和存储所有中间激活值。而StreamBP的创新在于,它将vec(Zout)分解为多个部分[vec(Z^(1)_out), vec(Z^(2)_out), ..., vec(Z^(D)_out)],然后利用链式法则的线性分解性质,将梯度计算分解为这些部分的和:

?L/?vec(W) = Σ(i=1到D) (?vec(Z^(i)_out)/?vec(W))^T · (?L/?vec(Z^(i)_out))

通过这种分解,计算每个部分所需的内存大大减少,因为我们只需要存储当前部分的激活值,而不是整个序列的激活值。这就像我们只需要记住诗歌的一小段,而不是整首诗。

二、StreamBP如何应用于大语言模型的各个组件

大语言模型主要由两个关键组件构成:语言建模头(language modeling head)和Transformer层。StreamBP巧妙地对这两个组件进行了处理,大幅降低了内存使用量。

首先,让我们看语言建模头。它执行一个简单的线性变换:H·Wlm_head = logits,其中H是最后一个Transformer层的输出,Wlm_head是语言建模头的权重,logits是模型的原始预测值。在长序列训练中,logits及其梯度占用了大量内存,因为它们的维度与序列长度和词汇表大小成正比。

StreamBP对不同训练目标(SFT、GRPO和DPO)采用了不同的处理策略。以监督微调(SFT)为例,其目标函数是:LSFT(logits, Y) := -Σ(t=1到T-1) log softmax(logitst,:)Yt。StreamBP将logits沿序列维度均匀分割成D个块,然后对每个块分别计算梯度,再累加起来。这样,我们只需要存储当前块的logits和梯度,而不是整个序列的logits和梯度,内存使用量降低至原来的1/D。

接下来是Transformer层,它包含注意力机制和多层感知机(MLP)两部分。传统方法需要存储Q、K、V、M、O、Hup、Hgate、Hout等大量中间值,占用大量内存。StreamBP的关键观察是:计算H^(i)_out对W的梯度只依赖于O^(i)、Q^(i)、K^(:i)和V^(:i)。

基于这一观察,StreamBP对每个块依次执行分区注意力和MLP计算,只存储当前块的激活值,然后累加梯度。这样,StreamBP只需要存储Q^(i)、K、V、M^(i)、O^(i)、H^(i)_up、H^(i)_gate和H^(i)_out,内存使用量降低至原来的约1/D。特别是,当使用分组查询注意力(grouped query attention)时,K和V的内存开销更是只有Q的1/G(G为分组大小)。

三、StreamBP的内存效率与计算效率分析

除了内存效率,StreamBP还有一个令人惊喜的特性:它实际上比标准反向传播需要更少的计算量!这就像我们发现了一种既节省空间又节省时间的菜谱一样。

在长序列训练中,Transformer层最耗费计算资源的操作是计算注意力分数S = QK^T。标准实现需要2T^2d的浮点运算,而StreamBP只需要(1+D)T^2d/D的浮点运算,大约减少了一半。这是因为StreamBP利用了语言模型的因果结构,在计算S^(i) = Q^(i)K^(:i)^T时只使用K^(:i)而非完整的K。

当然,StreamBP也有一些额外开销。每次分区梯度计算都需要从高带宽内存(HBM)加载模型权重W到寄存器中进行计算,这会产生额外的开销。同时,StreamBP大约减少了一半注意力掩码的HBM吞吐量。这些开销直接取决于分区数量D。

研究团队还开发了通信高效的分布式StreamBP版本,以支持多GPU训练。在分布式训练中,梯度通信和参数通信需要特别优化,因为不同GPU上的不同分区需要共享信息。分布式StreamBP的设计确保了通信成本与标准BP相当,同时保持了内存效率。

四、实验结果:StreamBP的惊人表现

研究团队对StreamBP进行了全面的实验评估,结果令人印象深刻。实验基于Qwen 3系列模型(4B、8B、14B和32B),在A800-80GB GPU上进行。

在内存使用方面,在80GB内存限制下,与梯度检查点技术相比,StreamBP将最大序列长度提高了2.8-5.5倍,与不使用梯度检查点的标准BP相比则提高了23.4-36.3倍!这就像原本只能在小餐盘上放3个菜,现在可以放15个菜了。

具体来说,对于Qwen 3-8B模型,StreamBP能处理长达110k的序列,而梯度检查点只能处理20k;对于Qwen 3-14B模型,StreamBP能处理40k序列,而梯度检查点只能处理9k;对于Qwen 3-32B模型(仅使用LoRA梯度),StreamBP能处理16k序列,而梯度检查点只能处理5.7k。

在时间消耗方面,StreamBP实际上比梯度检查点更快!随着序列长度增加,这种加速效果越发明显。例如,在序列长度为27k时,StreamBP比梯度检查点快约12.9%。这是因为StreamBP减少了计算注意力分数的浮点运算量。

值得注意的是,StreamBP的内存使用量与序列长度呈线性关系,这意味着它的序列长度扩展能力可以直接转化为批量大小的扩展,从而加速训练。例如,对于8B模型的SFT,StreamBP使得批量大小可以比梯度检查点大约4.5倍。

在不同训练目标下,StreamBP都表现出色。对于SFT、GRPO和DPO,它都显著增加了最大序列长度。例如,对于4B模型的SFT,StreamBP能处理200k序列,而梯度检查点只能处理28.5k。在分布式训练(Deepspeed ZeRO-2)环境中,StreamBP同样表现出色,将最大序列长度提高了约5-5.6倍,同时BP速度更快。

研究团队还研究了分区大小D对内存和时间消耗的影响。当序列长度相对较小时(小于10k),不同分区大小的BP时间相近。但随着序列长度增加,过小的分区大小会引入可观的开销。幸运的是,较大的分区大小只会引入很小的额外内存开销。因此,对于长序列训练,可以使用相对较大的分区大小来最大化训练效率。

五、StreamBP的广泛应用前景与局限性

StreamBP可以无缝集成到任何Transformer模型的训练流程中,为各种需要处理长序列的任务提供支持。特别是在训练推理模型(reasoning models)时,这种能力尤为重要。

无论是使用强化学习(RL)还是监督微调(SFT)来增强模型的推理能力,都需要处理包含详细推理过程的超长序列。随着训练的进行,序列长度往往会增加。StreamBP正是为这种场景量身定制的解决方案。

当然,StreamBP也有一些局限性。目前,它不支持MoE(混合专家)或多模态模型,不过研究团队表示,这些问题可以通过简单的实现扩展来解决,因为底层原理保持不变。此外,StreamBP的分区大小对BP时间有明显影响,这种开销可以通过使用融合反向算子来减少HBM吞吐量。

总的来说,StreamBP为大语言模型的长序列训练提供了一个极具价值的解决方案。它不仅大幅降低了内存使用量,还在某些情况下提高了计算效率。这使得研究人员能够在有限的硬件资源下训练模型处理更长的序列,从而增强模型在复杂任务上的能力。

想象一下,你原本只能阅读一篇短文,现在可以阅读一本完整的书,而且速度还更快——这就是StreamBP带来的巨大飞跃。对于AI研究和应用来说,这无疑是一个激动人心的进步,将有助于开发出能够理解和处理更长、更复杂内容的AI系统。

如果你对这项研究感兴趣,可以访问GitHub(https://github.com/Ledzy/StreamBP)查看代码,或阅读完整的研究论文(arXiv:2506.03077v1)深入了解技术细节。我相信,随着这种技术的普及,我们将看到更多能够进行深入推理和处理超长文本的AI系统涌现。

分享至
1赞

好文章,需要你的鼓励

推荐文章
----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.-