什么是推测解码?

推测解码(Speculative Decoding)是一种用于加速大型语言模型(LLM)推理的创新技术。它通过使用一个较小的"草稿模型"来预测下一个 token,然后用更大的"目标模型"来验证这些预测,从而显著减少推理时间。

技术原理

基本思想

推测解码的核心思想是:

  1. 使用一个小而快的草稿模型(draft model)来生成多个候选 token
  2. 使用大而准确的目标模型(target model)来并行验证这些候选 token
  3. 接受所有被验证正确的 token,拒绝第一个错误的 token 并重新采样

实现细节

详细工作流程

推测解码的具体工作流程如下:

  1. 初始预测:目标模型(LLM)预测第一个 token
  2. 草稿生成:草稿模型基于当前上下文生成多个候选 token(通常 3-5 个)
  3. 并行验证:目标模型并行验证所有草稿模型生成的候选 token
  4. 比较接受:逐个比较目标模型和草稿模型的输出:
    • 如果相同:接受该 token,继续比较下一个
    • 如果不同:停止接受,进入下一轮预测
  5. 下一轮循环:从未被接受的 token 位置开始新的预测循环
    自回归的LLM生成
    Speculative decoding

伪代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def speculative_decoding(prompt, draft_model, target_model, max_tokens, k=5):
generated = []

while len(generated) < max_tokens:
# 目标模型预测第一个token(如果需要)
if not generated:
next_token = target_model.sample(prompt)
generated.append(next_token)
continue

# 草稿模型生成k个候选token
current_context = prompt + generated
draft_tokens = draft_model.generate(current_context, num_tokens=k)

# 目标模型并行验证所有候选token
target_probs = []
for i in range(len(draft_tokens)):
verify_context = current_context + draft_tokens[:i]
next_token_probs = target_model.predict_next_token(verify_context)
target_probs.append(next_token_probs)

# 逐个比较并接受token
accepted_tokens = []
for i, draft_token in enumerate(draft_tokens):
target_prob = target_probs[i][draft_token]
draft_prob = draft_model.get_probability(draft_token, current_context + draft_tokens[:i])

accept_prob = min(1, target_prob / draft_prob) if draft_prob > 0 else 0

if random.random() < accept_prob:
accepted_tokens.append(draft_token)
else:
break # 第一个不匹配就停止

generated.extend(accepted_tokens)

# 如果全部接受,继续下一轮;否则从目标模型重新采样
if len(accepted_tokens) == k:
continue
else:
resample_context = current_context + accepted_tokens if accepted_tokens else current_context
next_token = target_model.sample(resample_context)
generated.append(next_token)

return generated

模型选择

  • 草稿模型:通常选择参数量较小、推理速度快的模型
  • 目标模型:需要高精度的大模型,负责最终的质量保证

并行验证

推测解码的关键优势在于能够并行验证多个候选 token。

性能优势

加速比

推测解码通常能达到 2-3 倍 的加速比,具体取决于:

  1. 草稿模型的质量:草稿模型越准确,接受率越高
  2. 候选长度:合适的候选序列长度平衡了并行性和准确性
  3. 模型大小差异:目标模型和草稿模型的参数量比例

考虑到推测解码代码逻辑独立于模型前向,引入额外开销使得整体加速效果只有 0.3-0.5 倍

硬件要求

推测解码对硬件的要求相对灵活:

  • GPU 内存:需要同时加载两个模型,需要显存更大
  • 计算资源:草稿模型的计算开销较小,但并行 verify 计算开销倍增
  • 通信开销:模型间需要高效的数据传输

挑战与限制

技术挑战

  1. 模型一致性:草稿模型和目标模型需要保持输出分布的一致性
  2. 候选长度选择:过长的候选序列可能降低接受率
  3. 错误传播:错误的预测会影响后续生成质量

适用场景

推测解码最适合:

  • 批量推理任务
  • 对延迟敏感的应用
  • 资源受限的环境

不太适合:

  • 需要极高准确性的关键任务
  • 非常短的文本生成