微信扫一扫,关注公众号

  • 科技行者

  • 算力行者

见证连接与计算的「力量」

首页 帝国理工学院突破性成果:不用定制芯片也能让AI模型飞速运行

帝国理工学院突破性成果:不用定制芯片也能让AI模型飞速运行

2026-03-19 21:39
分享至:
----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.-
2026-03-19 21:39 科技行者

这项由帝国理工学院领导的研究发表于2026年3月的arXiv预印本平台,论文编号为arXiv:2603.09555v1,为AI模型的运行效率带来了革命性的突破。有兴趣深入了解的读者可以通过该论文编号查询完整研究内容。

想象一下,你买了一款最新的智能手机,但发现它只能在特定品牌的充电器上充电,换了其他充电器就无法正常工作。这正是目前AI领域面临的一个重要问题。当前最先进的状态空间模型(比如著名的Mamba-2)就像这款挑剔的手机一样,它们通常只能在英伟达的GPU上高效运行,因为它们依赖于专门定制的计算内核。这种限制就像给AI的发展戴上了一副沉重的枷锁,让许多研究者和开发者望而却步。

帝国理工学院的研究团队发现了一个令人兴奋的解决方案。他们仔细分析了Mamba-2模型的工作原理,发现这个模型的核心算法——状态空间对偶性(SSD)——实际上具有一些非常特殊的数学特性。这些特性包括对角状态结构、可分块的递归计算,以及主要由批量矩阵运算组成的计算过程。研究团队意识到,这些特性恰好与现代编译器(特别是谷歌的XLA编译器)最擅长优化的计算模式完美匹配。

这个发现的意义就像发现了一把万能钥匙。传统上,要让Mamba-2这样的模型高效运行,就必须为特定的硬件编写专门的计算内核,这就像为每种锁都单独配一把钥匙。而研究团队发现的方法则让标准的编译器就能自动生成高效的代码,就像有了一把能开所有锁的万能钥匙。

一、化繁为简的智慧:为什么编译器也能胜任专业工作

要理解这项研究的突破性,我们需要先了解什么让某些计算任务特别适合编译器优化。这就像烹饪一样,有些菜谱特别适合按部就班地照做,而另一些则需要大厨的即兴发挥。

Mamba-2的状态空间对偶算法恰恰属于前一种情况。它的计算过程可以分解为几个关键步骤,每一步都有着清晰的结构。首先是分块处理,就像把一大块肉切成相等大小的小块来烹制,算法将输入序列分割成固定大小的块(每块256个令牌)。在每个块内部,原本需要逐步进行的递归计算可以转换为并行的矩阵乘法,这就像同时在多个炉灶上煎制不同的牛排,而不是一块接一块地烹制。

更重要的是,所有的重计算都可以表达为批量的张量运算。这些运算的形式非常规整,就像工厂流水线上的标准化操作,编译器可以轻松识别并优化它们。当算法需要条件判断时(比如下三角掩码的因果结构),它使用静态掩码而不是运行时分支。这意味着所有的条件都在编译时就已经确定,不会在运行时产生不可预测的分支,保持了整个计算流程的可预测性。

研究团队还发现了另一个关键因素:精度管理。他们精心设计了不同计算步骤使用的数值精度。残差连接使用float32精度以防止误差在多层网络中累积,衰减参数在对数空间中使用float32精度然后在计算时取指数以避免下溢,归一化层将输入转换为float32进行方差计算后再转回原精度。这些精心设计的精度标注替代了定制内核提供的细粒度数值控制。

二、实现理论与现实的完美结合:O(1)缓存的技术突破

传统的Transformer模型在处理长序列时面临着一个根本性问题:它们需要存储所有先前标记的键值对,这意味着内存使用量会随着序列长度线性增长。这就像你每写一个字都要把之前写的所有字都复印一遍保存起来,序列越长,需要的存储空间就越大。

状态空间模型的理论优势在于它们可以将整个历史压缩到一个固定大小的状态中。这个状态的大小不依赖于已经处理的序列长度,理论上实现了O(1)的内存复杂度。但是,要在实际系统中真正实现这个理论优势,需要解决一个关键的工程挑战:如何在设备上维护和更新这个状态,而不需要与主机进行频繁的数据传输。

研究团队的解决方案是将Mamba2Cache实现为JAX的PyTree数据结构。这个缓存包含了每层的SSM状态和卷积状态,都注册为JAX可以跟踪和放置的数据结构。这样,jax.jit编译器和jax.lax.fori_loop函数就能在编译时识别这些状态,并将整个解码循环编译为设备上的程序,无需主机同步。

在实际测试中,这种设计带来了显著的性能提升。对于130M参数的模型,Python主机循环每秒只能生成662个标记,而编译后的设备循环能达到1588个标记每秒,提升了2.4倍。虽然随着模型规模增大这个差距会缩小(因为单步计算时间增加),但对于中小型模型,循环执行策略直接决定了能否实现O(1)缓存的理论优势。

生成下一个标记的过程变得非常简洁:首先更新卷积状态,这只需要在滑动窗口中插入新输入;然后进行单步SSM递归更新,形式为h_t = A * h_{t-1} + B * x_t;最后通过线性投影生成输出。整个过程完全在设备上完成,没有主机设备间的数据传输开销。

三、从算法到代码的精妙转换:SSD算法的JAX实现

将数学算法转换为高效的计算机代码是一门艺术,需要在保持算法本质的同时,让代码适合现代编译器的优化策略。研究团队在这方面展现了精湛的技艺,他们将复杂的状态空间对偶算法压缩到了不到60行的Python代码中。

连续时间状态空间模型的数学表达相当优雅。它通过微分方程h'(t) = Ah(t) + Bx(t)和输出方程y(t) = Ch(t) + Dx(t)来描述系统的动态行为。当这个连续系统通过零阶保持进行离散化后,得到递归形式h_t = A_bar * h_{t-1} + B_bar * x_t。Mamba-2的创新在于让B、C和Δ参数依赖于输入,并将A限制为每个注意力头的对角标量。

关键的突破来自于对这个递归的重新组织。在一个包含L个标记的块内,顺序递归可以等价地表达为结构化的矩阵向量乘积:Y_diag = (L ⊙ CB^T)X,其中L是一个下三角矩阵,包含从累积衰减因子exp(segsum(A·Δ))导出的值。这个变换将块内的计算从顺序变为并行,而块间的状态传播仍然是顺序的但计算量很轻。

研究团队将这个数学表达直接映射到JAX的原语上。他们发现XLA编译器能够自动识别和优化这种计算模式。element-wise操作链(softplus → clip → exp → einsum)被融合为单个megakernel,而einsum运算被直接映射到目标设备的矩阵单元上。

掩码操作的处理展现了编译器友好设计的重要性。当算法需要条件计算时(如衰减矩阵的下三角因果结构),静态掩码(通过jnp.tril应用)保持了融合优化,而运行时分支则会破坏融合图并可能强制同步。应用tril到预计算矩阵让XLA将操作融合到周围的element-wise链中,而运行时分支或主机端条件则会破坏融合图。

四、性能表现的全面验证:从理论到实践的完美印证

要验证一个新方法的有效性,最有说服力的证据就是在真实硬件上的性能表现。研究团队在谷歌云TPU v6e上进行了全面的性能评估,这个芯片具有918 TFLOPS的BF16峰值计算能力和1600 GB/s的HBM带宽。

在自回归生成任务中,缓存策略带来了显著的性能优势。对于不同规模的模型,缓存版本的吞吐量保持恒定,与序列长度无关,而非缓存版本的性能随序列长度增加而急剧下降。以2.7B参数模型为例,在序列长度4096时,缓存版本每秒生成95个标记,而非缓存版本只有3个标记每秒,差距超过30倍。

内存使用模式同样印证了O(1)缓存的理论优势。缓存解码的峰值内存使用保持常数(2.7B模型约10.9GB),而非缓存路径的内存使用随序列长度线性增长。在序列长度4096时,2.7B非缓存路径消耗超过16GB内存,而缓存路径始终保持10.9GB的恒定占用。

硬件利用率方面的结果展现了编译器生成代码的高质量。在预填充(compute-bound)任务中,XLA在2.7B规模达到约140 TFLOPS,占TPU v6e峰值性能的15%。虽然这个数字看起来不高,但实际上已经接近单序列预填充在该芯片上的屋顶线限制,因为达到算术平衡需要约574 FLOPs/byte的比率,这是批次大小为1无法达到的。

在解码(memory-bound)任务中,表现更加令人印象深刻。2.7B模型达到了64%的HBM带宽利用率,这意味着编译器生成的代码能够有效地利用可用的内存带宽。带宽利用率在不同序列长度下的变化很小,每个模型的变化幅度都在1.7个百分点以内,显示了性能的稳定性。

数值正确性验证确保了这种优化没有以牺牲准确性为代价。与PyTorch/CUDA参考实现相比,贪婪解码在64个生成步骤中产生了完全相同的标记序列。虽然由于不同的归约树排序导致的float32累积差异产生了微小的数值漂移(约2×10^-4),但这种漂移在功能上是无害的,因为离散的标记轨迹完全相同。

五、平台通用性的惊人表现:一套代码跑遍天下

这项研究最令人印象深刻的特点之一是其出色的平台通用性。相同的JAX源代码可以在CPU、NVIDIA GPU和Google Cloud TPU上运行,无需任何修改。这种通用性的实现源于研究团队巧妙地利用了XLA编译器的跨平台能力,以及他们对算法结构的深刻理解。

在NVIDIA A100 GPU上的测试结果证明了这种跨平台兼容性的价值。虽然由于硬件差异性能数字有所不同,但同样的O(1)缓存优势和线性扩展特性得到了完美复现。这意味着研究者和开发者不再需要为不同的硬件平台维护不同的代码版本,大大降低了开发和维护成本。

更重要的是,这种通用性为AI模型的部署开辟了新的可能性。传统上受限于NVIDIA GPU的Mamba-2现在可以在各种硬件上运行,包括Google的TPU、苹果的M系列芯片,甚至是普通的CPU。这种硬件无关性对于AI技术的普及和应用具有重要意义。

编译时间的分析揭示了这种方法的一个权衡。XLA编译器为2.7B模型在序列长度4096时需要多达43秒的编译时间。虽然在生产服务中编译后的HLO程序可以在请求间重复使用,但对于迭代研究来说,这是kernel-free方法的主要延迟开销。不过,考虑到一次编译可以服务大量推理请求,这个开销在实际应用中是可以接受的。

六、技术创新的深层价值:超越单一模型的普遍意义

这项研究的价值远远超出了对单个模型的优化。研究团队识别出的算法特征——对角状态结构、可分块递归、静态控制流和einsum主导的计算——为评估其他状态空间模型的编译器友好性提供了一个清晰的框架。

消融研究(ablation studies)揭示了每个设计决策的重要性。静态掩码与动态循环的对比显示,前者在1.3B模型上达到42,631 tokens/s的预填充性能,而后者只有7,330 tokens/s,性能下降82.8%。虽然两种方法产生完全相同的输出,但动态变体破坏了XLA的融合链,导致性能急剧下降。

精度管理的重要性通过衰减指数化的实验得到了证明。在BF16精度下进行衰减计算会在130M模型的24层中累积误差,产生足以影响输出分布的误差(最大绝对误差0.013)。float32上转换是正确性要求,而非优化选择。

这些发现为未来的模型设计提供了重要指导。算法设计者现在知道哪些特征使模型适合编译器代码生成,哪些设计选择会破坏优化。这种知识可以指导新模型的开发,使它们从一开始就对编译器友好,而不是事后添加优化。

研究的局限性也为未来工作指明了方向。固定的块大小(L=256)和批次大小(1)是简化假设,实际应用可能需要更灵活的配置。对于涉及数据依赖内存访问、warp级同步或数据依赖控制流的操作,当前的标准原语方法可能无法充分利用硬件特性。不过,SSD算法巧妙地避免了所有这些复杂情况,这正是它适合编译器优化的原因。

说到底,这项研究代表了AI系统工程方法的一个重要转变。它表明,通过深入理解算法结构和现代编译器能力,我们可以在不牺牲性能的情况下实现更好的硬件通用性和代码简洁性。对于状态空间模型这样的新兴架构,定制内核现在是可选的而非必需的,这为更广泛的研究和应用奠定了基础。

这种编译器优先的方法可能会影响未来模型架构的设计。当算法设计者知道某些计算模式特别适合编译器优化时,他们可能会在创新时考虑这些约束,创造出既在数学上优雅又在工程上实用的解决方案。这种理论与实践的结合正是推动AI领域持续进步的关键所在。

Q&A

Q1:状态空间模型相比传统Transformer有什么优势?

A:状态空间模型最大的优势是内存效率。Transformer需要存储所有历史标记的信息,序列越长占用内存越多,而状态空间模型只需要固定大小的状态来压缩整个历史,实现了理论上的O(1)内存复杂度。这使得它们在处理长序列时更加高效。

Q2:为什么这项研究说定制内核变成可选的了?

A:传统上像Mamba-2这样的状态空间模型必须依赖专门为NVIDIA GPU编写的CUDA内核才能高效运行。这项研究发现,Mamba-2的核心算法具有特殊的数学结构,正好适合现代编译器(如XLA)自动优化,因此不再需要手写专门的计算内核,标准编译器就能生成高效代码。

Q3:这种编译器友好的方法有什么实际意义?

A:最直接的意义是硬件通用性。同一套代码可以在CPU、NVIDIA GPU、Google TPU等不同硬件上运行,不再被限制在特定平台上。这大大降低了AI模型的部署门槛,让更多研究者和开发者能够使用先进的状态空间模型,推动AI技术的普及应用。

分享至
0赞

好文章,需要你的鼓励

推荐文章
----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.- ----..---.-...-/--...-.-......./-...-....-..--../-............-.-