OA0
OA0 是一个探索 AI 的社区
现在注册
已注册用户请  登录
OA0  ›  代码  ›  MSA:记忆稀疏注意力机制,一个可扩展的、端到端可训练的潜在记忆框架

MSA:记忆稀疏注意力机制,一个可扩展的、端到端可训练的潜在记忆框架

 
  integration ·  2026-03-19 21:05:06 · 27 次点击  · 0 条评论  

MSA: 记忆稀疏注意力

一个可扩展、端到端可训练的潜在记忆框架,支持 1 亿令牌上下文

论文代码模型

License: MIT


📝 摘要

长期记忆对通用智能至关重要,然而完全注意力的瓶颈将大多数大语言模型(LLM)的有效上下文长度限制在 128K–1M 之间。现有的尝试,如混合线性注意力、固定大小的状态记忆(例如 RNN)以及RAG/智能体等外部存储方案,要么在极端规模下遭受精度快速衰减和延迟增长,要么缺乏端到端的可微分性或动态内存维护能力,或者需要复杂的处理流程。我们提出了记忆稀疏注意力(MSA):一个端到端可训练、可扩展的稀疏 潜在状态记忆框架。其核心思想包括:

  • 可扩展稀疏注意力 + 文档级 RoPE(并行/全局),在训练和推理中均实现近线性复杂度
  • KV 缓存压缩内存并行推理引擎,在 2×A800 GPU 上实现 1 亿令牌的吞吐量;
  • 内存交错机制,支持跨分散内存片段进行多轮、多跳推理。

在长上下文问答和“大海捞针”(NIAH)基准测试中,MSA 超越了同骨干网络的 RAG、最优的 RAG 栈以及领先的长上下文模型。在空前的 16K→1 亿令牌范围内,MSA 表现出 < 9% 的性能衰减,为解耦记忆容量与推理能力提供了一条可行路径。

从 16K 扩展到 1 亿令牌:MSA 将 Top-k 选择与稀疏注意力融合,保持端到端可微分的同时,允许在推理时解耦文档。在 MS MARCO 数据集上,MSA 保持 <9% 的性能衰减,并展现出强大的外推能力。
部分基线曲线因其上下文限制而提前终止。

图 1: 16K→1 亿令牌的扩展曲线
图 1: MSA 在超长上下文下的可扩展性


✨ 核心贡献

  • 记忆稀疏注意力(MSA):一种端到端可训练、可扩展的稀疏注意力层,结合文档级 RoPE,实现了 O(L) 复杂度,并在 16K→1 亿令牌范围内保持 <9% 的性能衰减。
  • KV 缓存压缩 + 内存并行:分层存储(GPU 驻留路由键,CPU 存储内容 K/V)、分布式评分以及按需传输,使得在 2×A800 GPU 上实现 1 亿令牌推理成为可能。
  • 内存交错:自适应交替执行“生成式检索 → 上下文扩展 → 生成”,显著提升了跨文档的多跳推理能力。
  • 全面评估:在长上下文问答和 NIAH 任务上,MSA 优于同骨干网络的 RAG、最优的 RAG 流程以及顶尖的长上下文模型,展现出卓越的稳定性准确性

🧩 整体设计

架构

MSA 将检索与生成集成到一个单一的可微分循环中。文档的潜在状态(K/V/Kᵣ)通过分块平均池化进行压缩。一个路由投影器通过余弦相似度计算相关性(在注意力头上平均池化,然后取令牌级最大值),选择 Top‑k 文档,然后将它们压缩后的 K/V 与查询的局部 K/V 拼接,用于自回归解码。路由仅应用于上层网络;下层网络保持独立的文档处理,以实现层次化对齐。

  • 并行(文档级)RoPE:每个文档的位置从 0 开始重置,防止训练时短推理时长之间的位置漂移,使得 64k 训练能够外推到 1 亿。
  • 全局 RoPE(活动上下文):查询的起始索引偏移 k(Top‑k 检索到的块),保持因果顺序:背景 → 查询 → 生成

图 2: MSA 层(稀疏注意力 + 文档级 RoPE)

图 2: MSA 层
图 2: 记忆稀疏注意力层及并行/全局 RoPE


推理流程

MSA 采用三阶段流程(图 3):

  1. 全局记忆编码(离线):对语料库进行前向传播,缓存分块池化后的 (K̄, V̄, K̄ᵣ)
  2. 在线路由与上下文组装:将查询投影为 Qᵣ,与 K̄ᵣ 匹配以选取 Top‑k,然后仅加载选中的 K̄/V̄ 并与局部上下文拼接。
  3. 稀疏生成:在稀疏上下文上进行自回归生成。

内存并行K̄ᵣ 分片到多个 GPU 上(查询广播 → 本地评分 → 全局归约)。内容 K̄/V̄ 保留在主机 DRAM 中,并在被选中时异步获取——平衡 VRAM吞吐量,以实现 1 亿令牌的部署。

图 3: 三阶段推理与内存交错

图 3: 推理流程
图 3: 离线编码 → 在线路由 → 稀疏生成;可选的多轮交错用于多跳推理


🚀 实验结果

实验设置
问答任务:9 个数据集(MS MARCO v1, NQ, DuReader, TriviaQA(10M), NarrativeQA, PopQA, 2WikiMultiHopQA, HotpotQA, MuSiQue),记忆库大小 277K→1000 万令牌,评估指标:LLM 评判(0–5 分)
NIAH(RULER):8 个子任务,32K→100 万令牌,报告平均准确率。
骨干网络:Qwen3‑4B‑Instruct‑2507。与同骨干网络的 RAG 以及最优的 RAG 栈(KaLMv2 + 大型生成器,可选重排序器)进行比较。

表 2: MSA 对比同骨干网络 RAG(Qwen3‑4B)

总结:平均得分 3.760,优于标准 RAG(+16.0%)、RAG+重排序(+11.5%)以及使用其最佳@k 的 HippoRAG2(+14.8%);在同骨干网络组内,除 NarrativeQA 外,MSA 在所有数据集上领先。

数据集 令牌数 Qwen3-4B R@1 R@5 R@10 Qwen3-4B (RR) R@1 R@5 R@10 HippoRAG2 R@1 R@5 R@10 MSA (自适应)
MS MARCO v1 7.34M 2.893 3.011 3.005 2.934 3.032 3.017 2.676 3.005 3.019 4.141
Natural Questions 1.47M 3.452 3.374 3.297 3.494 3.408 3.385 3.338 3.389 3.374 3.545
DuReader 277K 3.726 3.579 3.594 3.848 3.618 3.607 2.941 3.485 3.415 4.155
TriviaQA (10M) 10M 4.133 4.414 4.273 4.313 4.375 4.391 4.188 4.430 4.367 4.621
NarrativeQA 538K 1.611 2.567 2.860 3.638 3.492 3.536 1.959 2.628 2.655 3.395
PopQA 1.18M 2.959 3.273 3.299 3.315 3.264 3.266 3.111 3.249 3.249 3.433
2WikiMultiHopQA 722K 1.065 3.055 3.136 1.187 3.057 3.159 1.045 3.180 3.330 4.280
HotpotQA 1.35M 2.252 3.582 3.787 2.642 3.990 4.022 3.230 3.770 3.970 4.061
MuSiQue 1.41M 0.936 1.752 1.928 1.144 1.960 1.965 1.020 1.907 2.095 2.211
平均 2.559 3.179 3.242 2.946 3.355 3.372 2.612 3.227 3.275 3.760

表 2: 同骨干网络 RAG 与 MSA 对比(@1/@5/@10 对比 MSA 自适应)


表 3: MSA 对比最优 RAG 栈(大型骨干网络)

总结:与 KaLMv2+Qwen3‑235BKaLMv2+Llama‑3.3‑70B(带/不带重排序)相比,MSA 在 4/9 的数据集上取得最佳分数,平均得分为 3.760,相对于各最强配置的相对提升分别为 +7.2%+5.0%+10.7%+5.4%。在少数数据集(如 MuSiQue)上的差距主要归因于参数量差异和内在推理能力。

数据集 KaLMv2 + Qwen3‑235B R@1 R@5 R@10 Qwen3‑235B (RR) R@1 R@5 R@10 KaLMv2 + Llama‑3.3 R@1 R@5 R@10 Llama‑3.3 (RR) R@1 R@5 R@10 MSA (自适应)
MS MARCO v1 2.846 3.028 3.027 2.886 3.020 2.995 2.649 2.904 2.919 2.881 2.955 2.952 4.141
Natural Questions 3.711 3.670 3.694 3.621 3.610 3.645 3.675 3.674 3.662 3.756 3.665 3.647 3.545
DuReader 4.044 3.991 3.978 3.973 3.932 3.891 4.051 3.846 3.742 3.967 3.776 3.780 4.155
TriviaQA (10M) 4.367 4.656 4.578 4.492 4.320 4.555 4.273 4.740 4.719 4.547 4.703 4.695 4.621
NarrativeQA 1.413 2.130 2.427 3.212 3.427 3.375 1.290 2.123 2.382 3.150 3.263 3.317 3.395
PopQA 2.810 3.347 3.396 3.268 3.380 3.376 2.787 3.298 3.305 3.337 3.384 3.362 3.433
2WikiMultiHopQA 2.646 3.579 3.582 1.855 3.381 3.583 1.339 3.263 3.445 1.651 3.332 3.541 4.280
HotpotQA 3.497 4.090 4.225 3.341 4.141 4.194 3.070 3.896 4.127 3.428 4.145 4.203 4.061
MuSiQue 1.988 2.462 2.647 1.801 2.522 2.605 1.704 2.317 2.258 1.895 2.462 2.614 2.211
平均 3.036 3.439 3.506 3.161 3.526 3.580 2.760 3.340 3.396 3.179 3.521 3.568 3.760

表 3: SOTA RAG 栈(强检索器 + 大型生成器 + 可选重排序器)与 MSA 对比


图 4: RULER NIAH 稳定性(32K→1M)

总结:MSA 在 100 万令牌时保持 94.84% 的准确率。未经修改的骨干网络在超过 128K 后性能崩溃(在 100 万令牌时降至 24.69%)。混合线性注意力的长上下文模型在 ≥128K/256K 时性能明显下降。外部记忆智能体(例如 RL‑MemoryAgent‑14B)保持稳定,但绝对准确率较低,且比 MSA 表现出更陡峭的性能衰减。

图 4: RULER NIAH 32K→1M
图 4: 准确率 vs 上下文长度(越高越好)


实现说明

  • 训练:使用辅助路由损失进行 1589.5 亿令牌的持续预训练,随后进行两阶段 SFT(8k→64k 课程学习)。
  • 消融实验(论文表 4):课程扩展、内存交错、持续预训练以及注入原始文本均有显著贡献;移除它们会导致 5%–37% 的性能下降(取决于任务)。

引用

@misc{chen_2026_19103670,
  author       = {Chen, Yu and
                  Chen, Runkai and
                  Yi, Sheng and
                  Zhao, Xinda and
                  Li, Xiaohong and
                  Zhang, Jianjin and
                  Sun, Jun and
                  Hu, Chuanrui and
                  Han, Yunyun and
                  Bing, Lidong and
                  Deng, Yafeng and
                  Chen, Tianqiao},
  title        = {MSA: Memory Sparse Attention for Efficient End-to-
                   End Memory Model Scaling to 100M Tokens
                  },
  month        = mar,
  year         = 2026,
  publisher    = {Zenodo},
  doi          = {10.5281/zenodo.19103670},
  url          = {https://doi.org/10.5281/zenodo.19103670},
}

致谢

本仓库和文档页面由 MSA 作者维护。有关项目更新,请访问主页:https://evermind.ai/

27 次点击  ∙  0 人收藏  
登录后收藏  
0 条回复
关于 ·  帮助 ·  PING ·  隐私政策 ·  服务条款   
OA0 - Omni AI 0 一个探索 AI 的社区
沪ICP备2024103595号-2
耗时 25 ms
Developed with Cursor