自回归语言模型每次前向传播只生成一个token,这一"串行瓶颈"在推理阶段造成巨大的计算浪费——GPU的并行算力被严重低估,驱动了多令牌预测方向的系统性探索。
Vaswani等人在《Attention Is All You Need》中确立了Transformer的自回归解码范式:每步以前序所有token为条件预测下一个token,训练目标为最大化 $\log p(x_t | x_{
以NAT(Non-Autoregressive Transformer, Gu et al. 2018)为代表的研究试图一次性生成所有token,彻底打破串行约束。但实验反复证明:独立预测各位置token会导致严重的多模态崩塌(multimodal collapse)——模型无法协调相邻token间的依赖,生成质量大幅下降。这一时期的教训是:完全去除自回归依赖代价过高,需要更温和的折中方案。
Cai et al.(2023)提出Medusa,在冻结的LLM主干上附加多个轻量"草稿头"(draft heads),每个头独立预测未来第 $k$ 步的token,再用树形注意力(tree attention)并行验证多条候选路径。这是第一个在工业级模型上实现无损加速的多令牌方案,Meta、Together AI等机构随即跟进。
Meta在《Better & Faster Large Language Models via Multi-Token Prediction》(Gloeckle et al., NeurIPS 2024)中提出将多令牌预测作为训练目标而非推理技巧:模型在训练时同时优化未来 $n$ 步的预测,共享主干表示,每步有独立输出头。实验表明这不仅加速推理,还显著提升了代码生成等需要长程规划的任务质量——因为预测未来多步迫使模型学习更全局的语义表示。
以MARS为代表的新一代工作发现:无需从头训练,通过轻量级微调(LoRA量级的参数量)即可让已有模型获得多令牌预测能力,大幅降低了应用门槛,使该技术向边缘部署和小团队普及。
设序列长度为 $N$,标准自回归训练目标为: $$\mathcal{L}_{\text{AR}} = -\sum_{t=1}^{N} \log p_\theta(x_t \mid x_{
多令牌预测的整体逻辑是:用一次前向传播的共享表示驱动多个并行输出头,推理时以树形验证将串行步数折叠,训练时以多步监督信号丰富梯度。
输入token序列经过标准Transformer主干,得到每个位置的隐状态 $h_t \in \mathbb{R}^d$。这一步与普通LLM完全相同,无架构修改。关键在于:$h_t$ 必须同时承载"当前位置语义"和"对未来多步有预测力的全局信息",这一双重压力正是MTP训练的核心价值所在。
在主干顶部附加 $n$ 个轻量输出头(通常为单层线性投影或小型MLP),每个头 $k$ 独立预测 $x_{t+k}$。头的参数量极小(约为主干的 $1\%$),不显著增加显存和计算。为什么不用 $n$ 个独立Transformer层?因为深层特征提取已由主干完成,额外层只需做"任务适配",轻量头足够。
$n$ 个头各自输出 top-$m$ 候选token,组合成一棵候选树(共 $m^n$ 条路径)。为控制验证开销,实践中用束搜索或动态剪枝将树规模限制在可接受范围(通常 $<64$ 个节点)。树形注意力(tree attention)通过修改注意力掩码,使主干在一次前向传播中并行验证所有路径——这是Medusa的核心工程贡献。
python # 伪代码:树形注意力掩码构建 def build_tree_mask(tree_paths): # tree_paths: List[List[int]], 每条路径是token索引序列 n_nodes = sum(len(p) for p in tree_paths) mask = torch.zeros(n_nodes, n_nodes, dtype=torch.bool) for path in tree_paths: for i, node in enumerate(path): # 每个节点只能看到其祖先节点 mask[node, path[:i+1]] = True return mask主干对树中每个节点重新计算概率,从根到叶贪心选取最长一致前缀作为本轮输出。若第 $k$ 步草稿token的概率超过阈值(贪心)或通过拒绝采样(保证分布无偏),则接受并继续;否则截断。这一机制保证了输出分布与原始模型完全等价(无损),是区别于非自回归方法的关键保证。
训练阶段,$n$ 个头的损失加权求和后统一反传至主干。梯度从多个未来步同时流入 $h_t$,相当于主干在每个位置同时接受来自 $n$ 个监督信号的约束,实验表明这显著改善了需要长程规划的任务(如代码补全、数学推理)的表现,因为模型被迫学习"下一步之后还会发生什么"。
多令牌预测已成为工业级LLM推理加速的主流方案之一。Meta在Llama 3系列中集成了MTP训练目标,实测代码生成任务提升显著;Together AI、Groq等推理服务商将Medusa类方案作为标配加速层,实现2–3倍吞吐提升而无质量损失。对音视频生成领域,该技术正被迁移至音频token序列生成(如EnCodec token流),有望将实时语音合成的延迟进一步压缩。其价值不仅在加速,更在于揭示了"多步预测作为训练信号"这一正则化视角的普适性。
当前核心开放问题:①接受率建模——如何在训练时显式优化接受率而非事后调整;②动态头数——不同难度token应激活不同数量的预测头,静态 $n$ 是次优的;③与投机解码的统一理论——MTP草稿头与独立草稿模型在信息论层面的等价条件尚不清晰;④多模态扩展——视频/音频token流的时序依赖结构与文本不同,树形验证策略需重新设计。