传统Next-Token Prediction训练大模型时,一次只预测下一个词。多Token预测(MTP)则一次性预测多个未来词,加速学习并提升推理能力。本文拆解MTP原理,探讨它对大模型训练效率与泛化能力的影响。
传统的大模型训练以Next-Token Prediction(单Token预测)为基础。模型一次只看下一个词,学得慢,且容易忽略长距离依赖。最近,Meta AI等机构提出Multi-Token Prediction(MTP),让模型一次预测多个未来词。
MTP并不复杂:输入序列后,模型同时预测第 t+1、t+2……t+k 个Token。输出端配置多个独立的预测头(heads),每个头负责一个位置。这些头共享Transformer主干网络,只在最后几层分叉。训练时,所有头的损失加权求和,一起反向传播。
这种设计鼓励模型更快捕捉语义结构。比如预测“我昨天去了”之后,传统模型只学“商场”这个词,但MTP会同时预测“商场”“看电影”“很开心”,相当于强迫模型在更宽的上下文中做决策。
MTP的核心优势在于梯度信号更丰富。单Token预测时,每个位置只有一条梯度路径;MTP则引入k-1条额外路径,训练效率提升明显。实验表明,MTP在相同数据量下能让模型困惑度(perplexity)更低。
更重要的是,MTP天然适合生成任务。推理时,模型可以一次输出多个Token(通过缓存各预测头的概率),再基于某种规则(如自回归或并行解码)生成。这能降低推理延迟,对对话系统、联网搜索等场景很有价值。
国内大模型竞赛中,百度文心、阿里通义、DeepSeek都关注过类似思路。DeepSeek在自家模型上尝试过MTP变体,发现对长文本生成(如代码补全、论文生成)的连贯性提升明显。但MTP也带来训练显存增加——预测头越多,参数越大。
另一个问题是如何选择k值。k太小效果不明显,k太大则训练不稳定,且推理时的并行收益递减。目前主流实验取k=3~5。
MTP不是取代Next-Token Prediction,而是补充。它让模型在相同计算量下学到更多——这正是AI训练效率竞赛中的关键破局点。
参考:Meta AI《Better & Faster Large Language Models via Multi-Token Prediction》
免费获取企业 AI 成熟度诊断报告,发现转型机会
关注公众号

扫码关注,获取最新 AI 资讯
3 步完成企业诊断,获取专属转型建议
已有 200+ 企业完成诊断