GLM-4-Z1 模型设计做错了一件事

DeepSeek-R1 API 在 OpenAI 兼容 API 基础上为思考模型增加了 "reasoning_content" 字段,可以将 "<think>" "</think>" 中间的内容独立返回,非思考部分通过 "content" 字段返回。

SGLang/vLLM 等框架也支持了这样的返回,比如 SGLang 可以在启动时增加参数 "--reasoning-parser deepseek-r1",让服务器像 DeepSeek-R1 API 那样返回。

Chatbox 等支持 OpenAI 兼容 API 的聊天 APP,对 DeepSeek-R1 风格的返回结果支持也更好,上面截图就是一个例子,当思考部分以 "reasoning_content" 字段返回时,显示效果与正文不同,背景是灰色,而且可折叠。

上一篇博客中提到我给 SGLang 添加了 GLM-Z1 模型的支持,然后我发现 GLM-Z1 系列模型无法做到思考内容和非思考内容分开输出。如果你在启动服务时增加了 "--reasoning-parser deepseek-r1" 参数,那么所有生成的内容都会放在 "reasoning_content" 字段返回,而 "content" 为空。Chatbox 显示效果如下所示:

我仔细研究了一下,发现了问题在 GLM-Z1 模型上:GLM-Z1 系列模型在设计时,没有像 DeepSeek-R1 那样将 "</think>" 编码为一个独立的 Token。

GLM-Z1 系列模型输出的 "</think>" 由 3 个 Token 组成:"</"、 "think" 和 ">",这就会导致 框架在流式逐字输出时,没法简单地判断一个新生成的 Token 是否为 "</think>",以决策在何时结束 "reasoning_content" 字段。

可以通过下面代码简单验证:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("GLM-Z1-9B-0414")
print(tokenizer("</think>\n"))

DeepSeek-R1 是在 tokenizer.json 中标记了 "</think>" 作为 "added_tokens",QwQ-32B 则把 "</think>" 放在了 added_tokens.json 中。这就意味着,DeepSeek-R1 和 QwQ-32B 输出的 "</think>" 是 1 个 token,很方便框架去完整地比对它是否存在。框架的确也可以兼容非一个 token 的 "</think>",但编码更为复杂,效率也会偏低。

建议思考模型的设计者都注意一下这个小小的细节。

有关 GLM-4-0414 的 SGLang 推理支持

在智谱的 GLM-4-0414 系列模型发布后,我观察到一个有意思的现象。GLM-4-0414 相较于 GLM-4 修改了模型结构,初期仅支持通过 transformers 库推理,但你去搜索它有哪些推理支持的话,什么也找不到,被大量的垃圾自媒体文章淹没,很多相较于官方模型仓库的 README 来说都是 0 信息量。

有些垃圾自媒体,可能连个 AI 都不如。

说回来 GLM-4-0414,虽然它名字看起来是 GLM-4 的延续,但是仔细看 config.json 你就会发现,GLM-4 的模型结构是 ChatGLMModel,而 GLM-4-0414 的模型结构是 Glm4ForCausalLM。这就导致很多推理框架都要对其进行重新适配。

vllm 可能是最早支持 GLM-4-0414 的开源框架,看起来是智谱员工提的 PR,但是第一版实现有两个 bug,会导致模型加载失败或者输出错误结果。由于对 GLM-Z1-9B-0414 模型在报告中声称大幅优于 DeepSeek-R1-Distill-Qwen-14B 过于好奇,我忍不住自己尝试在 SGLang 里适配了一下 GLM4-4-0414,见 PR#5485

嗯,学到了很多东西,也踩了两个 bug 其中的一个,哭。就说说这两个 bug 吧,其实都是源自于对 vllm/SGLang 库算子的不了解。

在 transformers 库中,很多算子仅实现其算子名表示的朴素的功能,但是 vllm/SGLang 代码库的一些算子,除了其朴素功能以外,往往还通过更多的参数实现对前后计算的融合、简化,导致其实际计算需要深入探究。

BUG 1: 融合 RMSNorm

以比较基础的 LLaMA 为例,transformers 中的 LLaMA DecoderLayer 的实现是这样的(删去了一些无关细节),这个跟大家理解的 Transformer 模型结构伪代码是容易对上的。

  1. def forward():
  2. residual = hidden_states
  3. hidden_states = self.input_layernorm(hidden_states)
  4. # Self Attention
  5. hidden_states, self_attn_weights = self.self_attn()
  6. hidden_states = residual + hidden_states
  7. # Fully Connected
  8. residual = hidden_states
  9. hidden_states = self.post_attention_layernorm(hidden_states)
  10. hidden_states = self.mlp(hidden_states)
  11. hidden_states = residual + hidden_states
  12. outputs = (hidden_states,)
  13. return outputs

但是 vllm/SGLang 中的实现是这样的:

  1. def forward():
  2. # Self Attention
  3. if residual is None:
  4. residual = hidden_states
  5. hidden_states = self.input_layernorm(hidden_states)
  6. else:
  7. hidden_states, residual = self.input_layernorm(hidden_states, residual)
  8. hidden_states = self.self_attn(...)
  9. # Fully Connected
  10. hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
  11. hidden_states = self.mlp(hidden_states)
  12. return hidden_states, residual

粗看你会有点懵,仔细研究就会发现,SGLang 里面通过给 RMSNorm 传入第二个参数,实现了 Norm 与 Add 的融合。但是这种融合需要调整计算顺序,影响了参数和返回值的类型,并且也影响了最后一次 model.norm() 的计算。

相似的还有 SiluAndMul() 算子。

BUG 2: get_rope 参数

GLM-4-0414 的 config.json 中提供了 partial_rotary_factor 参数,作用于 head_dim 上。有两种应用方法,一种是提前计算好 rotary_dim = int(partial_rotary_factor * self.head_dim),然后把这个参数传进去;另一种是令 rotary_dim = head_dim,然后传入 partial_rotary_factor。

  1. def get_rope(
  2. head_size: int,
  3. rotary_dim: int,
  4. max_position: int,
  5. base: int,
  6. is_neox_style: bool = True,
  7. rope_scaling: Optional[Dict[str, Any]] = None,
  8. dtype: Optional[torch.dtype] = None,
  9. partial_rotary_factor: float = 1.0,
  10. ) -> RotaryEmbedding:
  11. ...
  12. if partial_rotary_factor < 1.0:
  13. rotary_dim = int(rotary_dim * partial_rotary_factor)

我重复犯的这个错误就是将 partial_rotary_factor 应用了两遍,rotary_dim 也计算了,partial_rotary_factor 参数也传了(因为参考了 vllm 实现和旧的 SGLang chatglm 实现,甚至看到别人提交的 bugfix 都不认为这里有错),其实就相当于应用了 partial_rotary_factor * partial_rotary_factor。这个 BUG 的现象就是导致 GLM-4 模型的输出大概率陷入死循环而无法结束。

BTW:transformers 库里的 glm4 代码没有读这个参数,而是硬编码在 apply_rotary_pos_emb() 算子中,所以这不是一个可调参数。

PR#5485 还在 Review 中,有需要的朋友可以使用 https://github.com/solrex/sglang/tree/glm4 这个分支支持 GLM-4-0414 系列模型在 SGLang 的推理。

git clone -b glm4 https://github.com/solrex/sglang.git
cd sglang
pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
python3 -m sglang.launch_server --model /workspace/GLM-Z1-9B-0414 --enable-torch-compile --torch-compile-max-bs 128 --cuda-graph-max-bs 128 --host 0.0.0.0 --port 8000

Llama-4 的 expert 参数组织问题和 FP8 量化方法

Llama-4 把模型结构里 expert 的参数组织搞得太恶心了。BF16 参数还好,如果想做一下量化,就会面临一堆麻烦。

先说一下 Llama-4 开源参数的问题。

HF BF16版本:3 维 experts 参数,gate_proj 和 up_proj 融合

Llama-4-Scout/Maverick 的 HuggingFace 版本参数 BF16 版本中,路由 expert 专家参数是按照 3 维存储的,而且 up_proj 和 gate_proj 放在了同一个 key 里。比如 Llama-4-Scout 的专家参数在 safetensors 文件中是这样存储的:

TensorsShapePrecision
language_model.model.layers.n.feed_forward.experts.down_proj
[16,8192,5120]
BF16
language_model.model.layers.n.feed_forward.experts.gate_up_proj
[16,5120,16384]
BF16

其中 n 是层数索引。Maverick 的 experts Shape 只是从 [16, , ] 变成了 [128, , ],其余都一样。

这个形状是与 Transformers 库中 llama4 的模型代码保持一致的:

// https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py
class Llama4TextExperts(nn.Module):
    def __init__(self, config: Llama4Config):
        ...
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
        ...

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        ...
        gate_up = torch.bmm(hidden_states, self.gate_up_proj)
        gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
        next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
        ...

HF FP8版本:2 维 expert 参数,gate_proj 和 up_proj 分离

在 FP8 量化版本 Llama-4-Maverick-17B-128E-Instruct-FP8 中,expert 参数又被拆成了二维的,通过 key 中的索引标识属于哪个专家

TensorsShapePrecision
language_model.model.layers.n.feed_forward.experts.m.down_proj.weight[5120, 8192]F8_E4M3
language_model.model.layers.n.feed_forward.experts.m.down_proj. weight_scale[5120, 1]BF16
language_model.model.layers.n.feed_forward.experts.m.gate_proj.weight[8192, 5120]F8_E4M3
language_model.model.layers.n.feed_forward.experts.m.gate_proj. weight_scale[8192, 1]BF16
language_model.model.layers.n.feed_forward.experts.m.up_proj.weight[8192, 5120]F8_E4M3
language_model.model.layers.n.feed_forward.experts.m.up_proj. weight_scale[8192, 1]BF16

其中 n 是层数索引,m 是层中的专家数索引。

之前说到,Transformers 库中 modeling_llama4.py 是只支持融合模型的,那这种格式的参数怎么加载?

哎,人家用了一个办法:如果读到模型配置里有量化配置,在加载模型前修改一下模型结构。硬是在 transformers/quantizers/base.py 中增加了一个 _convert_model_for_quantization 方法,如果模型有子 module 叫做 "Llama4TextExperts",在量化 preprocess 的时候就给替换成 SequentialLlama4TextExperts 实现,不使用原始的 Llama4TextExperts 实现。

注意哦,这个替换对所有量化模型都生效。这种特化方法,要放在我的团队里,CR 估计都过不了。

怎么量化 llama4-scout ?

FB 官方提供了 128 专家的 FP8 模型,但是没有提供 16 专家的 FP8 量化模型。毕竟 16 专家模型也 200 多 G,如果想量化 16 专家模型,该怎么做呢?

meta 官方在 source/en/model_doc/llama4.md 里推荐的方法,是加载模型时使用 FbgemmFp8Config 进行 online 量化,这个我没跑成功,看着错误像是只支持单卡跑 fbgemm 量化,但是 H800 显存不够。如果这个问题可以解决,欢迎留言告诉我方法。

$ torchrun --nproc-per-node=4 test-16-fbgemm.py
self.pre_quantized False
Loading checkpoint shards: 100%|███████| 50/50 [00:45<00:00, 1.09it/s]
...
[rank0]: File "/workspace/venv-fbgemm/lib/python3.10/site-packages/transformers/integrations/fbgemm_fp8.py", line 52, in forward
[rank0]: x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
[rank0]: File "/workspace/venv-fbgemm/lib/python3.10/site-packages/torch/_ops.py", line 1158, in __call__
...
[rank0]: File "/workspace/venv-fbgemm/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 486, in propagate_op_sharding_non_cached
[rank0]: raise NotImplementedError(
[rank0]: NotImplementedError: Operator fbgemm.quantize_fp8_per_row.default does not have a sharding strategy registered.

咱又没有那么大显存的卡,只能想别的办法,能不能仿照 Llama-4-Maverick-17B-128E-Instruct-FP8 来转换出来一个 16 专家的 FP8 模型呢?

Llama-4-Maverick-17B-128E-Instruct-FP8 怎么转出来的?

首先了解一下 llama4 的转换脚本:

  • 在 github llama-models 代码库里,有一个 llama4/scripts/quantize.py 脚本,是用来将原始的 pytorch 模型参数,通过 fbgemm 转成 FP8 量化的 pytorch 模型参数。
  • 在 transformers 代码库里,有一个 llama4/convert_llama4_weights_to_hf.py 脚本,是用来将原始的 pytorch 模型参数,通过映射表映射到 huggingface 的 safetensors 模型参数。

然后来看 llama4 发布的模型:

它不是通过 convert_llama4_weights_to_hf.py 转换 Llama-4-Maverick-17B-128E-Instruct-FP8-Original 来的,因为:

1. FP8-Original 的量化是通过 fbgemm 做的,FP8 模型的量化是用 compressed-tensor 做的;

2. FP8-Original 是分 TP 做的量化,与 FP8 的量化方法也不同。

它也不是通过 llm-compressor(compressed-tensor) 转换 Llama-4-Maverick-17B-128E-Instruct 来的。因为 compressed-tensor 目前仅支持 Linear 算子的量化,但前面说过,加载原始 BF16 模型用的是 bmm 算子。

convert_llama4_weights_to_hf.py 中有一个 _OFFLINE_QUANT_COMPATIBLE 的参数,可以控制是否对专家进行融合。但是这样转出来的模型,modeling_llama4.py 是无法加载的。

我猜测 meta 线下可能替换了 modeling_llama4.py 中的专家层到 SequentialLlama4TextExperts,成功加载模型以后,然后再通过 llm-compressor 进行的模型转换

然后就这样试了一下,先替换 transformers 库源码 modeling_llama4.py 中的 experts 组,并从源码安装 transformers:

--- a/src/transformers/models/llama4/modeling_llama4.py
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -153 +153,2 @@ class Llama4TextMoe(nn.Module):
-        self.experts = Llama4TextExperts(config)
+        from transformers.quantizers.base import SequentialLlama4TextExperts
+        self.experts = SequentialLlama4TextExperts(config)

然后再执行:

export OFFLINE_QUANT_COMPATIBLE=1
python3 src/transformers/models/llama4/convert_llama4_weights_to_hf.py --instruct --convert_checkpoints --input_dir /workspace/Llama-4-Scout-17B-16E-Instruct-Original --output_dir /workspace/Llama-4-Scout-17B-16E-Instruct-Split-Experts

然后再使用 llm-compressor 库,用下面的 recipe 参数对模型进行量化,这样就成功了。转换出来的模型仓库内容基本与 Llama-4-Maverick-17B-128E-Instruct-FP8 相同。

recipe = QuantizationModifier(
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=[
're:.*lm_head',
're:.*self_attn',
're:.*router',
're:.*vision_model',
're:.*multi_modal_projector',
're:.*shared_expert',
're:.*feed_forward.gate_proj',
're:.*feed_forward.up_proj',
're:.*feed_forward.down_proj'
],
)

为什么?

我是想不通 llama4 为啥要把简单的事情搞这么复杂,这样一改,很多开源库都要去适配这种格式的 MoE 参数,但功能上看起来又没有任何变化。做融合的 expert 参数有什么实际的收益吗?

在 32 张 L40S/L20 上运行 DeepSeek-R1/V3 原版 FP8 模型

上文讲到,FP8 模型之所以无法 TP32 运行,主要因为 DeepSeek R1/V3 模型保存的参数是 FP8 128x128 量化的。Attention 还好,128 个头做 TP16 或者 TP32 都没问题,问题主要出在专家的计算上。

前三层 MLP 的 intermediate_size 是 18432,18432 做 TP16 是 1152,相当于 9 个 128。但如果做了 TP32,就是 4.5 个 128,这会导致参数无法切分在 128 边界上,无法支持按 128 block 量化。路由专家也类似,moe_intermediate_size 做 TP32 相当于 0.5 个 128。

缩小 DeepSeek-R1/V3 的量化块到 64x64

如果想支持 TP32,一个很显然的路径就是把量化方式改成按 64x64 分块量化。本来这是一个比较复杂的操作,但我想了一个取巧的办法:直接把 128x128 的缩放系数,复制到 4 份。

为了方便理解这个方案,我画了一张图。假设我们有一个 4x8 的 INT32 矩阵,按照 4x4 block 量化到 INT8,它会分成 1x2 个 4x4 的块,每块一个缩放系数,那就是 1x2 个缩放系数。如下图所示,第一个块的缩放系数是 7.3465,它是通过第一个块里的最大绝对值 |-933|/127 得到的,同理第二个缩放系数来自 974/127。

那如果我想将它的量化 block 缩小到 2x2,理论上我应该计算每个 2x2 block 的最大绝对值,然后 /127 得到缩放系数,这样精度损失最小。可是我嫌麻烦,偷个懒,我直接把 4x4 的缩放系数复制 4 份,虽然精度有损失,但好处是在计算上与 4x4 的量化结果完全一致。换句话说,就是原汁原味,纯血参数。

将这个逻辑迁移到 DeepSeek FP8 量化的 128x128 block 缩小到 64x64,原理是一样的,也是将 scale 参数矩阵进行 2x2 等值扩充。通过非常简单的参数处理,就能够实现将 DeepSeek 原始模型转成 64x64 的分块量化,然后就可以用 SGLang 加载运行了。

运行方法

我们以昨天发布的 DeepSeek-V3-0324 为例,逐步说明如何使用这种方法在 L40S 和 L20 上运行 FP8 满血+纯血版的 DeepSeek-V3-0324,不需要等待美团再发布 INT8 版本。

假设你已经下载好了模型,在 /workspace/DeepSeek-V3-0324/。那你需要先下载我的开发分支,并通过源代码安装它(如果遇到困难,建议你在 SGLang 的开发 Docker 中执行它):

git clone -b l40s-dsfp8 https://github.com/solrex/sglang.git
cd sglang
pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python

然后用下面这个脚本,将 128x128 量化的 DeepSeek-V3-0324,转换到 64x64 量化的 DeepSeek-V3-0324-Block64x64:

python3 scripts/resize_block_size.py /workspace/DeepSeek-V3-0324/

当你在 4 台(或 8 台)机器上都完成了 SGLang 安装和参数拷贝后,就可以用下面的命令来启动 SGLang 服务了。注意替换 MASTER_IP 和 TCP_IFACE 到正确的值。

# MASTER_IP: 主节点 IP
# TCP_IFACE: 主网卡接口名,可通过 ifconfig 获取

# 主节点
NCCL_DEBUG=INFO NCCL_IB_GID_INDEX=3 NCCL_SOCKET_IFNAME=TCP_IFACE GLOO_SOCKET_IFNAME=TCP_IFACE python3 -m sglang.launch_server --model /workspace/DeepSeek-V3-0324-Block64x64/ --tp 32 --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 0 --trust-remote --enable-torch-compile --torch-compile-max-bs 32 --cuda-graph-max-bs 32 --host 0.0.0.0 --port 8000

# 从节点 1
NCCL_DEBUG=INFO NCCL_IB_GID_INDEX=3 NCCL_SOCKET_IFNAME=TCP_IFACE GLOO_SOCKET_IFNAME=TCP_IFACE python3 -m sglang.launch_server --model /workspace/DeepSeek-V3-0324-Block64x64/ --tp 32 --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 1 --trust-remote --enable-torch-compile --torch-compile-max-bs 32 --cuda-graph-max-bs 32

# 从节点 2
NCCL_DEBUG=INFO NCCL_IB_GID_INDEX=3 NCCL_SOCKET_IFNAME=TCP_IFACE GLOO_SOCKET_IFNAME=TCP_IFACE python3 -m sglang.launch_server --model /workspace/DeepSeek-V3-0324-Block64x64/ --tp 32 --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 2 --trust-remote --enable-torch-compile --torch-compile-max-bs 32 --cuda-graph-max-bs 32

# 从节点 3
NCCL_DEBUG=INFO NCCL_IB_GID_INDEX=3 NCCL_SOCKET_IFNAME=TCP网卡 GLOO_SOCKET_IFNAME=TCP网卡 python3 -m sglang.launch_server --model /workspace/DeepSeek-V3-0324-Block64x64/ --tp 32 --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 3 --trust-remote --enable-torch-compile --torch-compile-max-bs 32 --cuda-graph-max-bs 32

性能

在 MASTER 节点使用下面的命令进行性能测试(需要先下载测试数据集 ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json),输入固定 200,输出固定 200,并发 128,测试两轮。

python3 -m sglang.bench_serving --backend sglang-oai --dataset-path /workspace/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json --dataset-name random --random-range-ratio 1 --random-input-len 200 --random-output-len 200  --request-rate 128 --num-prompt 256 --max-concurrency 128 --host localhost --port 8000

我测试的性能指标是:

=========== Serving Benchmark Result ============
Backend: sglang-oai
Traffic request rate: 128.0
Max reqeuest concurrency: 128
Successful requests: 256
Benchmark duration (s): 108.13
Total input tokens: 51200
Total generated tokens: 51200
Total generated tokens (retokenized): 50917
Request throughput (req/s): 2.37
Input token throughput (tok/s): 473.49
Output token throughput (tok/s): 473.49
Total token throughput (tok/s): 946.98
Concurrency: 127.38
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 53802.80
Median E2E Latency (ms): 53530.89
---------------Time to First Token----------------
Mean TTFT (ms): 7768.52
Median TTFT (ms): 7769.45
P99 TTFT (ms): 12003.95
---------------Inter-Token Latency----------------
Mean ITL (ms): 232.35
Median ITL (ms): 220.48
P95 ITL (ms): 242.17
P99 ITL (ms): 270.50
Max ITL (ms): 10250.34
==================================================

对照上一篇博客《刷新 32 张 L40S 运行 DeepSeek-R1-INT8 的性能数据》,看起来 FP8 Block 量化的性能比 INT8 Channel 量化的性能要差一些。

代码改动

这次的代码改动不大,主要是参数转换脚本,和一些针对 64x64 的 tunning,以及 warning 的修复。整理好后我会提个 PR 给 SGLang,但这次我不确定这个 PR 是否会被接受,感兴趣的同学可以直接看这个 commit:https://github.com/solrex/sglang/commit/03d34078d8d65983aabc0386391743cc43f535ed

地震

正当我写到这的时候,忽然手机通知地震了。生平第一次收到,记录一下。我完全没震感,但是有朋友感觉到了。

刷新 32 张 L40S 运行 DeepSeek-R1-INT8 的性能数据

前一篇博客《使 SGLang 支持在 32 张 L40S 上运行 DeepSeek-R1》中提到我那非常特殊的 L40S 显卡配置,结果发现是个大乌龙。

首先,这台机器有 PCIE 4.0 Switch,每张 Switch 上插了 4 张 L40S 显卡。我误会的 PCI-to-PCI 应该是挂了一个给监控屏用的小显卡。

其次,有一张网卡插错位置了。本来应该每 4 张显卡配 1 张双口网卡,其中一张没插到对应的 Switch 上。

最后,3 张网卡 6 个网口,只启用了 1 个网口,有 5 个网口没启用。

所以,实际上所有机内通信走的都是 PCIE,所有跨机通信走的都是主网卡,这……只能怪自己没经验,默认以为交过来的环境都是对的。

折腾了几天,总算搞对了,正确的拓扑如下:

主网卡两个网口做了链路聚合,用作 TCP 通信;PCIE Switch 上的 4 个网口专用作 RDMA 通信。同机两个 NUMA 域双卡之间通信,走 PCIE 大概 11GB/s,走 GDRDMA 能提升到 19GB/s。最关键的是多机通信,从单网口的小水管提升到了 4 网口 GDRDMA。

采用与上篇文章同样不严谨的测试方式:

# 128 并发
[TP0] Decode batch. #running-req: 128, #token: 86816, token usage: 0.44, gen throughput (token/s): 777.92, #queue-req: 0, # 之前 565.73, 1.37x
# 32 并发
[TP0] Decode batch. #running-req: 32, #token: 8923, token usage: 0.04, gen throughput (token/s): 457.66, #queue-req: 0, # 之前 260.73,1.75x
# 4 并发
[TP0] Decode batch. #running-req: 4, #token: 1437, token usage: 0.01, gen throughput (token/s): 153.98, #queue-req: 0, # 之前 42.3,3.64x
# 1 并发
[TP0] Decode batch. #running-req: 1, #token: 482, token usage: 0.00, gen throughput (token/s): 49.02, #queue-req: 0, # 之前 26,1.88x

固定 200 输入,200 输出,128 并发,2 轮请求:

============ Serving Benchmark Result ============
Backend: sglang-oai
Traffic request rate: 128.0
Max reqeuest concurrency: 128
Successful requests: 256
Benchmark duration (s): 87.57
Total input tokens: 51200
Total generated tokens: 51200
Total generated tokens (retokenized): 51023
Request throughput (req/s): 2.92
Input token throughput (tok/s): 584.64
Output token throughput (tok/s): 584.64 # 之前 391.47,1.49x
Total token throughput (tok/s): 1169.28
Concurrency: 127.23
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 43524.37
Median E2E Latency (ms): 43246.71
---------------Time to First Token----------------
Mean TTFT (ms): 7225.06
Median TTFT (ms): 6769.52
P99 TTFT (ms): 14229.37
---------------Inter-Token Latency----------------
Mean ITL (ms): 182.87
Median ITL (ms): 162.95
P95 ITL (ms): 166.60
P99 ITL (ms): 202.39
Max ITL (ms): 13711.39
==================================================

可以看到,修复网络本身存在的问题以后,推理性能提升还是很显著的。

one more thing

前一篇文章提到: L40S/L20 虽然支持 FP8 精度,却不能运行 FP8 的 DeepSeek-V3/R1。这个问题搞定了,我已经实现在 L40S 上运行原始 FP8 参数的 DeepSeek-R1,满血+纯血。等我整理一下代码,下篇博客来介绍一下。

使 SGLang 支持在 32 张 L40S/L20 上运行 DeepSeek-R1

我提交的 PR: Support serving DeepSeek-R1-Channel-INT8 with 32 L40S. #4418 [1] 已经合入到了 SGLang 的主干,也许这是第一个用 PCIE 互联的 GPU 小卡跑通 DeepSeek-R1 推理的例子。

有一些遇到的问题分享一下,在适配别的 GPU 时可以用来参考。

背景

因为 128x128 block 量化的 DeepSeek-R1/V3,参数维度在除 32 时,会遇到商无法被 128 整除的问题,所以即使 L40S 支持 FP8,也无法直接用 TP32 运行 DeepSeek-R1/V3。

感谢美团提供了 channel 量化的 DeepSeek-R1 参数 DeepSeek-R1-Channel-INT8 [2],并且在 SGLang 代码库做了适配。这让我可以在 48G 显存的 L40S 上尝试一下运行满血版(int8 量化) DeepSeek-R1,但尝试过程没有我预期的顺利,遇到了不少问题。

问题

一、shared memory OutOfResources

  File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/triton_ops/extend_attention.py", line 356, in extend_attention_fwd
_fwd_kernel[grid](
...
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 374, in _init_handles
raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 102400, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

服务启动正常,收到推理请求后,Attention kernel 报 shared memory 资源不足。我刚开始没理解这个问题,后来看了下代码,发现 triton attention 代码中对不同类型的 GPU 架构设置了不同的 block size。L40S 属于 SM89 架构,但 SM89 都归类到了 SM80 架构里。

查了一下 CUDA 编程手册,SM89 的 shared memory 大小是 100K,但 SM80 是 160K。我猜测大概率就是这个问题,所以给 SM89 单独加了个分支,缩小了 SM89 的 block size,解决了这个问题。

后来发现 sglang/test/srt/test_triton_attention_kernels.py 可以完美复现这个问题,而我却傻乎乎地每次重启整个服务去测试正确性。

二、gemm executioin failed RuntimeError

  File "/usr/local/lib/python3.10/dist-packages/sgl_kernel/ops/__init__.py", line 118, in int8_scaled_mm
return torch.ops.sgl_kernels.int8_scaled_mm(
File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
return self._op(*args, **(kwargs or {}))
RuntimeError: gemm executioin failed, error: Error Internal

收到推理请求后,有很大的概率触发这个错误,但也有小概率能完成一次推理。所以我没想到是算子问题,我以为是显存不足。调了半天各种显存占用的参数,后来没办法了才回过头来看实际的算子调用。

又是跟上面类似的问题,sgl_kernel 中自定义的 int8 gemm 算子将 SM89 归类到 SM80 进行矩阵计算的 dispatch。这显然会遇到与上面类似的问题,但是我又不知道 SM89 该怎么进行 dispatch,看起来需要做很多 benchmark 或者计算才能确定。

于是我就去翻 TensorRT-LLM 和 vLLM,让我给翻到了 vLLM 的实现,我就照着 vLLM 对 SM89 的 dispatch 逻辑抄了一遍。这次我学乖了,先看看有没有 test。跑通了 sglang/sgl-kernel/tests/test_int8_gemm.py,才去进行集成测试。

三、sub-optimal MoE

Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.10/site-packages/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json

这是一个次要问题,看起来是一个专门的配置没有找到。后来我研究了一下,应该是 triton 版本的 fused_moe 需要读取一个在每种类型的 GPU 上都 benchmark 过的最好配置来运行。

python  benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model /workspace/DeepSeek-R1-Channel-INT8 --tp-size 32 --dtype int8_w8a8 --tune

我按照文档的要求,在 L40S 上跑了一下 bench,把最终输出的文件拷贝到对应的位置就好了。这个 bench 真的要跑好久,大概两三个小时。

性能

我这 4 台 L40S 的硬件配置有些特殊:它是单机 8 卡,PCIE 4.0 连接到主机,但它既不是 2-2-4,也不是 2-2-8。

20250322:其实是硬件配置错误,详见:《刷新 32 张 L40S 运行 DeepSeek-R1-INT8 的性能数据

其中 4 张卡直插 PCIE 上,另外 4 张卡通过一个 PCI-to-PCI Bridge 插到 PCIE 上。为了弥补这样连接的带宽缺陷,PCI-to-PCI Bridge 上还接了 2 个 100Gb/s 的 RDMA 网卡。主机上每个 NUMA 域,分别也有 2 个 100Gb/s 的 RDMA 网卡。

所以这 8 张显卡,6 张网卡,拓扑如下所示:

我也实在算不出来这玩意儿的互联带宽。逻辑上来说,这大概相当于用 4 PCIE 4.0 GPU + 2 RDMA 网卡的性能,所以也许这个 4 x 8卡,实际上相当于 8 x 4卡。以下性能评测结果供参考。

下面是加载时的显存使用情况:

[TP0] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=22.43 GB, mem usage=21.27 GB.
[TP0] Memory pool end. avail mem=7.95 GB
[TP0] Capture cuda graph begin. This can take up to several minutes. avail mem=7.92 GB
[TP0] Capture cuda graph end. Time elapsed: 411.41 s. avail mem=5.89 GB. mem usage=2.02 GB.
[TP0] max_total_num_tokens=201723, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=2049, context_len=163840

下面是使用固定 200 token 输入时,不同并发下的净 decode 速度。因为 bench 速度太慢,我也懒得等最终的结果,所以这里直接节取了 log。

# 128 并发
[TP0] Decode batch. #running-req: 128, #token: 86816, token usage: 0.43, gen throughput (token/s): 565.73, #queue-req: 0,
# 32 并发
[TP0] Decode batch. #running-req: 32, #token: 8923, token usage: 0.04, gen throughput (token/s): 260.73, #queue-req: 0,
# 4 并发
[TP0] Decode batch. #running-req: 4, #token: 1439, token usage: 0.01, gen throughput (token/s): 42.30, #queue-req: 0,
# 1 并发
[TP0] Decode batch. #running-req: 1, #token: 482, token usage: 0.00, gen throughput (token/s): 26.00, #queue-req: 0,

这是固定 200 输入,200 输出,128 并发,2 轮请求的完整压测结果:

============ Serving Benchmark Result ============
Backend: sglang-oai
Traffic request rate: 128.0
Max reqeuest concurrency: 128
Successful requests: 256
Benchmark duration (s): 130.79
Total input tokens: 51200
Total generated tokens: 51200
Total generated tokens (retokenized): 50992
Request throughput (req/s): 1.96
Input token throughput (tok/s): 391.47
Output token throughput (tok/s): 391.47
Total token throughput (tok/s): 782.94
Concurrency: 127.49
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 65135.99
Median E2E Latency (ms): 64974.27
---------------Time to First Token----------------
Mean TTFT (ms): 17554.19
Median TTFT (ms): 19216.02
P99 TTFT (ms): 21662.98
---------------Inter-Token Latency----------------
Mean ITL (ms): 239.84
Median ITL (ms): 220.95
P95 ITL (ms): 233.20
P99 ITL (ms): 299.49
Max ITL (ms): 16077.21
==================================================

链接

[1] https://github.com/sgl-project/sglang/pull/4418

[2] https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8

[3] https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh

理解 FlashMLA 在 DeepSeek MLA 计算过程中的位置和作用

读过 DeepSeek V2 Paper 的人可能对下面这张图印象非常深刻,非常直观地揭示了 MLA 算法的基本原理。但结合这张图或者 MLA 公式去看 DeepSeek 在开源周发布的 FlashMLA [1],可能就一脸懵逼了。

图1: MLA 与其它 Attention 对比 [2]

理解 FlashMLA 代码库解决的是怎样一个问题,还需要更进一步的理解 MLA 计算的优化过程。

DeepSeek MLA 公式

下面这张图是 DeepSeek V2 Paper 中的 MLA 公式,相信读到这里的人都知道这是个啥。所以我不做解释,放这里主要是为了方便与下面的图进行交叉阅读。

图2: MLA 计算公式 [2]

MLA Naive 实现

最直接的实现 MLA 的方式,就是实现图 2 MLA 的所有计算过程,这个计算过程也与图 1 完全一致。我以 DeepSeek V3 的参数为例,做了下面这张图。

为简化复杂度,这张图里隐藏了两个维度,batch size 和 seq len,或者你可以把它简单地理解成,仅输入一个 token 给模型进行计算的状态。就像我在前面的博客《DeepSeek-V3 MTP 工程实现思考》中做的那样:你仅输入“how”这一个单词给模型,看看模型能给你生成什么。

图3: MLA Naive 实现,转化为 MHA 计算

每个绿色的方框代表一次矩阵乘/线性变换,框内是参数矩阵的名字和维度;每个白色的方框代表一个中间结果 tensor,框内是张量名字和维度;黄色的方框则代表核心的 Attention 计算,也就是图 2 中的公式 (46)。参数矩阵和中间结果 tensor 的名字与图 2 保持一致。

在 Naive 实现中,512 维的 Latent KV cKV 被映射回对应 128 个 head,每个 head 128 维的 K kC 和 V vC,然后再拼接上位置向量 kR最终形成标准的 q、k、v,输入到标准的 Multi Head Attention 进行 Attetion 计算。与其他常见模型中 MHA 的唯一不同,可能是 head dim 192 不是 2 的 n 次方。

Naive 实现最直观,但它的问题是在 Decode 计算时性能不够好。Decode 计算时,输入的 Q 往往只有一个 token,但需要用到所有前缀 token 的 k 和 v,也就是我们通常说的 KV Cache。Naive 实现有两种选择:

① 缓存 Latent KV。缓存规模小,但 Latent KV 缓存不能直接送 MHA 计算,还得经过 WUK 和 WUV 的线性映射,可以看到这是两个规模不小的矩阵计算,而且每轮都得重复计算。

② 缓存 KV。缓存规模大,不用重复计算,性能好。但 MLA 的一大好处就是 KV Cache 压缩,这样显存内能缓存更多 token,支持更大的 batch 和 prefix cache。如果缓存 KV,在显存上对比 MHA 就完全没有优势了。

所以,Naive 实现可能会用于 Prefill,但在 Decode 计算时需要更好的计算方法

MLA 优化实现

很多人把下面这种 MLA 的优化称为矩阵吸收[3],来源是 DeepSeek V2 里面这样说:

Fortunately, due to the associative law of matrix multiplication, we can absorb WUK into WUQ, and WUV into WO. Therefore, we do not need to compute keys and values out for each query. Through this optimization, we avoid the computational overhead for recomputing kCt and vCt during inference.

但我更喜欢把它理解成矩阵乘法交换律。因为实际上大家发现,提前将两个参数矩阵乘起来,即把 (WUQ)TWUK 的计算结果做为新的参数矩阵,在性能上还不如分开计算[3]。既然实际计算过程是交换矩阵计算过程,从“矩阵吸收”角度思考反而更绕了。

图4: MLA 优化实现,转化为 MQA 计算

上图中的两个虚线箭头,显示了在优化的计算过程中,哪些参数矩阵被交换了位置。它们能交换的原因,就是从数学上这样修改是等价的(矩阵乘法交换律)。

与图 3 相比,可以看到输入给 Attention 的 q、k、v 形状发生了明显的变化。q 的形状由 128x192 变化成了 128x576,k 的形状由 128x192 变化成了 576,v 的形状由 128x128 变化成了 512。这样一来,原来的 KV 就不存在了,新的计算过程中只剩下 Latent KV 了。而且实际上 V 也不存在了,因为 V 就是 K 的前 512 维。

再回看图 1,你会发现,这不就是 MQA 么?而这就是实际上 FlashMLA 代码库解决的问题:提供了一个专门为 q k head dim 为 576,v head dim 为 512,v 与 k 的前 512 维重叠,q head 数不超过 128(TP 下会变少)设计,仅支持 H800/H100 GPU 的 MQA 优化算子。

简单来说:虽然这个库叫做 FlashMLA,但它提供的 flash_mla_with_kvcache() 是个 MQA 算子,只不过这个 MQA 的输入形状有些特殊。

小知识

为什么会这样呢?因为开源软件约定俗成的 Attention 算子封装,仅仅指图 2 中公式(46)这一行,是不包含前后的线性变换的。开源推理框架允许用户通过配置选择不同的 Attention 算子实现,比如 FlashAttention、FlashInfer、Triton 实现等。

虽然 MLA 算法的核心在前后的线性变换,FlashMLA 算子却不能提供这些变换。这些线性变换只能被实现在模型建模的 modeling 代码的 MLA 模块中,比如 SGLang 代码库 python/sglang/srt/models/deepseek_v2.py 文件中的 DeepseekV2AttentionMLA [4] Module。

引用

[1] https://github.com/deepseek-ai/FlashMLA

[2] DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model, https://arxiv.org/pdf/2405.04434v5

[3] DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子, https://zhuanlan.zhihu.com/p/700214123

[4] https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/deepseek_v2.py

2 行代码校验大模型(如DeepSeek-R1)权重文件下载完整性

很多人在 DeepSeek-V3/R1 爆火之后,都希望体验本地运行“满血版”模型。但是满血版模型的权重参数文件有 600 多个 G,光权重文件就拆成了 163 个。

当你受不了 HuggingFace 官网的下载速度,用其它方法或者渠道获得了权重文件后,怎么确认这些权重文件是完整无损坏的呢?

这里介绍一个最简单的方法,仅需要 2 行代码。

环境

前提 1,你已经 clone 了不含权重文件的模型 git 仓库。以 DeepSeek-R1 为例,通过下面命令可以仅 clone 代码文件到 DeepSeek-R1 目录下:

GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1

前提 2,你已经用某种方法下载好了权重文件。请将这些权重文件放到已 clone 的 git 仓库目录内,以 DeepSeek-R1 为例,就是将 163 个 *.safetensors 文件移动到 DeepSeek-R1 目录下。

你也可以不移动权重文件,那么你就需要在执行第 2 行命令前将 checksum 文件移动到权重文件所在目录。

第 1 行代码

获得所有官方权重文件的 sha256 checksum,并保存成一个标准的 checksum 文件。这行代码需要在 git 仓库目录下执行

git lfs ls-files -l | awk '{print $1"  "$3}' > large_files.sha256

这行命令输出的文件内容形如:

c2388e6b127ce6664e35c5e2529c3ce4bfc99f4f7fb6fa48e92b29ed5e4922af  model-00001-of-000163.safetensors
5f450c75da7eb897b74a092eee65df8bb115fce81cccd2bbaeb220bd97197875 model-00002-of-000163.safetensors
...
913177d9e0dfb228769e0a13a386c34b919dcbb32a430ce230979f53bf7ae5bc model-00163-of-000163.safetensors

第 2 行代码

根据官方权重文件的 checksum,检查本地文件的完整性。这个检查的执行速度会非常慢,因为它需要为每个文件计算 sha256sum,然后再与 checksum 文件做比对。

sha256sum -c large_files.sha256

这行命令的输出形如:

model-00001-of-000163.safetensors: OK
model-00002-of-000163.safetensors: FAILED
...
model-00163-of-000163.safetensors: OK

如果所有行的输出都是 OK,那么恭喜你,所有权重文件都没有损坏;如果有某行输出为 FAILED,就代表该文件没有通过完整性校验,你需要重新下载它。

此方法对所有标记为 LFS 的文件均有效,并不仅限于 *.safetensors 文件,比如量化模型 *gguf 权重文件,也可以同样用此方法校验。

单机 KTransformers 运行 DeepSeek-R1-GGUF 4 bit 量化模型 Q4_K_M 实测

最近有些文章把 KTransformers 吹得没边儿,但是看到的实测案例非常少。我也比较好奇它的实际表现,所以来实测一下看看。

机器配置

硬件实测环境官方案例
CPUIntel(R) Xeon(R) Platinum 8350C CPU @ 2.60GHz, 单插槽 32 核,64 超线程,2 插槽,2 NUMA 节点Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 sockets, 2 numa nodes
内存64GB DDR4 2933MHz x 24,共 1.5 TB 内存standard DDR5-4800 server DRAM (1 TB), each socket with 8×DDR5-4800
GPUNvidia L40S, 48GB VRAM4090D 24G VRAM

实测环境机器配置看起来很强悍,但距离 KTransformer 首页给的官方案例配置还是有差距:

  1. 8350C 是第 3 代至强 CPU,官方案例用的 6454S 是第 4 代至强 CPU,Intel AMX 指令集只在第 4 代和第 5 代至强上支持,号称比前一代有 3~10 倍的推理性能提升;
  2. DDR4 2933 的访存带宽,跟官方案例用的 DDR5 4800,纸面数据差 60%;
  3. 虽说 L40S 比官方案例用的 4090 性能要更强,显存要更大,但目前 KTransformers 给的配置并不能完全发挥出来。

程序环境

基于 Pytorch 镜像 pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel ,本地编译 KTransformers 命令:make dev_install 。

注:通过 pip install 的 KTransformers 包会在运行时 crash 在 cpuinfer.so 里,我猜测是官方的包使用了更高级的指令集,而我这台机器不支持。

执行命令:numactl -N 1 -m 1 python ./ktransformers/local_chat.py --force_think --model_path DeepSeek-R1/ --optimize_rule_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml --gguf_path DeepSeek-R1-GGUF/DeepSeek-R1-Q4_K_M/ --cpu_infer 33 --max_new_tokens 1000

参数说明

--force_think: 强制在开头输出 <think> 标签

--model_path: 来自于 git clone https://huggingface.co/deepseek-ai/DeepSeek-V3,只需要下载代码,不需要下载参数文件

--gguf_path: 下载的量化参数,只需要下载子目录:https://huggingface.co/unsloth/DeepSeek-R1-GGUF/tree/main/DeepSeek-R1-Q4_K_M

--cpu_infer: 粗看代码,KTransformers 是一个线程分发任务,其余线程消费计算任务,所以官方案例这里总是配置的 2 的 N 次方 + 1。但是在实测时,是否 2 的 N 次方 + 1 区别不大。

实测结果

Prompt: 【《背影》全文不含标题】 请简明扼要地回答问题:这篇文章出自哪里?作者是谁?写于什么年代?

图1:生成结果和性能

可以看到性能,生成整篇答案花费了 4 分多钟,prefill 性能 11 toks/s,decode 性能 3.4 toks/s。这个比官方案例的性能慢了一倍多。而且这个回答太啰嗦了,没有遵循 prompt。

图2: CPU 和 GPU 利用率

图 2 是通过 top 和 nvidia-smi 查看的运行时的 CPU 和 GPU 利用率:可以看到内存占用 380 GB,CPU 33 核差不多用满;GPU 占用 13GB 显存,利用率大概 9%。

图3:同样 prompt DeepSeek Chat 回答

上图是 DeepSeek-R1 chat 官网的回答,看到比量化版好很多,至少遵循了“简明扼要地回答问题”的指令。

其它尝试

为了测出来最大性能,我尝试过加大 --cpu_infer 参数到 65、129,尝试过切换 optimize_rules 到 DeepSeek-V3-Chat-multi-gpu-4.yaml 或者 DeepSeek-V3-Chat-multi-gpu-8.yaml,实测都没有看到性能优化,很多甚至还劣化到 0.8 toks/s。

但是降低 --cpu_infer 参数到 25,没有观察到性能劣化。

观察

从 GPU/CPU 的利用率可以看出,KTransformers 主要靠 CPU 推理,AMX 指令集和内存访问速度很关键,GPU 利用率很低,反而不关键。

DeepSeek-R1-Q4_K_M 4 bit 量化模型较非量化模型效果有显著差距,可以观察到指令遵循都不太够。

KTransformers 目前对计算任务的拆分,并没有实现跟随 CPU 核数线性提升性能,这说明也许里面还有很多优化可以做。

讨论

现在大模型动辄几百 GB,需要 N 张显卡才能运行,客观上阻碍了很多感兴趣的人去体验和创新。

KTransformer 能用低配硬件慢速跑起来(接近)满血的模型,是非常赞的一个项目。它后续的持续优化也值得期待。

但 CPU 推理在成本上有它的局限性,大公司也不傻,如果 CPU 成本低,为啥要买那么多 GPU 卡?

核算推理成本要尊重客观事实,别动不动“告别天价显卡”。进行压力测试后,将合理计算的服务器成本平摊到每个请求/token上来计算。从这个角度看,大概有些人又要鼓吹“告别天价 CPU 服务器”了。

DeepSeek-V3 MTP 工程实现思考

一个东西从 idea 到实现,中间有着巨大的鸿沟,DeepSeek-V3 的 Multi-Token Prediction 也一样。虽然开源社区很多在一个多月前已经支持了基于 TP 的 DeepSeek-V3 推理,但是 MTP 部分目前都还在开发中,进展最快的可能是 vLLM,参见 vLLM PR #12755 [5] 。

但我觉得这并不是一个终点,可能还有很多工程优化工作需要继续完成。下面我尽量用浅显的图表和语言来说明我的理解和思考,如有错误也欢迎指出。

Speculative Decoding (投机解码)

理解 MTP 首先要理解 Speculative Decoding,这里不过多介绍,仅用一张图说明 Speculative Decoding 的计算过程,便于理解后续的分析。如果希望深入了解可以观看 这个 Youtube 视频 [1]。

图1:自回归和投机解码示例,来自视频:EAGLE and EAGLE-2 [1]

左边展示的是常规的 LLM 自回归迭代计算过程。

初始 prompt 是 token: "how",how 先经过 embedding 计算变成 ehow,然后经过 Decoder Layer 的计算,输出特征向量 fhow(最后一层 hidden states),经过 LM Head 转换成一个概率分布向量 phow,通过采样得到生成结果 token:"can"。

然后 LLM 会把 can 作为新的输入,进行下一步的计算,直到 LLM 给出推理结束 token:"<EOS>"。

自回归解码是逐 token 迭代进行的,生成一个 token,将 token 作为输入生成新的 token,直到遇到结束条件:生成 "<EOS>" 或者达到最大生成长度。

右边展示的是 Speculative Decoding 的过程。

初始 prompt 还是 how,但是通过其它方式(比如一个小模型,叫做草稿模型)先推测了两个草稿 token:"can、we",同时输入到目标模型。普通的 Decoder 实现仅能解码 1 个 token,这里改造成能够同时解码输出 3 个 token 的 hidden states。这样我们就能同时得到:phow, pcan 和 pwe。然后就可以跟草稿模型输出的 qhow 和 qcan 进行比较,验证是否接受草稿模型的草稿 token:can 和 we。

图上目标模型验证结果是接受 can,但是拒绝 we,那就使用 pcan 采样,得到生成结果 token:I。这就意味着,投机解码通过一次推理,得到了两个 token:can、I实现了1倍逻辑加速:can 是推测以后得到验证的 token,I 是拒绝推测 we 以后,根据目标模型自身输出采样的 token。

EAGLE 与 DeepSeek-V3 MTP

EAGLE 简单说来是作者认为通过目标模型本身的特征向量(就是上面的 fhow)预测下一个 token 更准确,所以草稿模型使用了与目标模型基本相同的结构,利用了目标模型输出的特征向量(fhow)作为草稿模型输入。如果希望深入了解可以观看 这个 Youtube 视频 [1]。

图2:EAGLE 和 DeepSeek-V3 MTP 的区别 [2][3]

MTP 与 EAGLE 不同的点如上图所示,除了多做了一次 Norm 这种细节之外,主要是多步推理的时候的串行。EAGLE 在多步推理时,只使用到了一个草稿模型做自回归推理;MTP 在多步推理时,其实是多个草稿模型进行串行推理

了解完上面这些背景以后,我们可以分析如果希望实现 DeepSeek-V3 MTP,都需要做哪些工作。

MTP 实现

1. MTP 加载

虽然很多框架都支持了 EAGLE,但一般的实现,都只支持 1 个草稿模型。而 MTP 从设计上,需要加载多个草稿模型,每一个 MTP 层,都是一个草稿模型。

在推理的时候,要根据不同的 step,选不同的模型进行推理。这就使得 MTP 草稿模型的加载和推理的调度比其它投机编码要复杂。

但如果 MTP 的步长等于 1,那就相当于 1 个草稿模型,实现会简单很多。

2. MTP Prefill

图3:DeepSeek-V3 MTP [3]

从上图可以看出,第 i 个 MTP Module 的输入 token,是第 i+1 个 token 到第 n 个 token,n 是当前生成的总长度。而它不仅需要 token 的 embedding,还需要 token 在前一个模型计算得到的 hidden states。

比如 MTP Module 1 的输入,是 token 2 到 5 的 embedding 和 main model 最后一层输出的 token 2 到 5 的 hidden states。

这也就意味着,在完成 DeepSeek-V3 的 prefill 时,需要输出最后一层的 hidden states,才能进行第 1 个 MTP 的 prefill;第一个 MTP 输出最后一层的 hidden states,才能进行第 2 个 MTP 的 prefill,以此类推。

可以注意到:多个 MTP 的多次 prefill 计算是串行的。这意味着每增加 1 个 MTP Module,每次推理的时候就要多一轮串行的 prefill,并且多一份 kv cache。一个主模型加 N 个小模型的推理,可能会严重影响计算调度的效率,可能这也是为什么 DeepSeek-V3 只输出了 1 个 MTP Module 的原因。大概他们也认为,仅使用 1 个 MTP Module 性价比最高

3.MTP PD 分离

我在之前一篇博客[4]中列举了 PD 分离背后面临的很多架构选择,MTP 会让 PD 分离变得更复杂。框架有两种选择:

选择一:Prefill 节点做 MTP Prefill:如下图所示,P 节点做完 DeepSeek-V3 Prefill 以后,保留最后一层所有 token(除了第 1 个,即index 0)的 hidden states,采样生成的第一个 token,获得 tokenid,然后将这些输入到 MTP Module 1 做 Prefill。最后将 1) DeepSeek-V3 61 层的 KV Cache; 2) DeepSeek-V3 MTP 的 KV Cache; 3) DeepSeek-V3 生成的第一个 tokenid;4) DeepSeek-V3 MTP 生成的第一个草稿 tokenid 和概率;这 4 部分传给 D 节点。

图4:DeepSeek-V3 MTP Prefill PD 分离计算方案

选择二:Prefill 节点不做 MTP Prefill:P 节点做完 DeepSeek-V3 Prefill 以后,把:1) DeepSeek-V3 61 层的 KV Cache; 2) 最后一层所有 token(除了第 1 个,即index 0)的 hidden states;3)所有 token (除了第 1 个,即 index 0)的 embedding。这 3 部分传给 D 节点。D 节点将生成第一个 token 的 hidden states 经过 LM Head 计算和采样获得 tokenid,然后对 MTP 进行 Prefill。

考虑到通信量和复杂度,大概大家都会选择一,但这样 Prefill 节点就必须加载 LM Head 了,因为 MTP 依赖生成的 tokenid 做 embedding 输入。

4. MTP MLA 算子优化

由于 MLA 的复杂性,现在的很多 MLA 实现并不支持在 decode 单次前向计算时同时并行计算多个 Query token,所以只能通过 Batch Expansion 进行投机解码。

Batch Expansion

以 how [can, we] 举例,我们可以展开成 3 个请求:

图5:投机编码的 Batch Expansion 并行计算方法

从逻辑上来看,请求变多了,但 3 个请求放到一个 batch 中可以进行并行计算,可以共享 prefix cache (如果先做 prefill 的话),这样我们依然可以拿到 phow, pcan 和 pwe。通过并行请求也能够实现 1 次 Decode 验证多个 token。

这里要注意一个逻辑:虽然要验证 2 个 token,但是却展开成了 3 个请求。这样如果全部两个草稿模型投机推理的 token 都被接受了,那第 3 个 token 会由目标模型自己生成这个 token 被称为 bonus token

虽然 Batch Expansion 能解决投机编码时的并行问题,但 Batch Expansion 有一定的计算开销。在高吞吐的时候,会抵消投机编码带来的加速。更好的优化就需要 MLA 算子在单次前向计算时,同时 decode 2 个 query token,这有一定的改造成本。

5. MTP 并行和 overlap 优化

以 DeepSeek-V3 的参数规模,模型并行必不可少。尤其是考虑到微批计算、通信的 overlap 带来的高效率,MTP 的推理未必能像一般的草稿模型一样,单独执行。

很有可能需要将 MTP 的推理和 Speculative Decoding 的打分、验证融入到 DeepSeek-V3 模型中。通过一次前向计算,完成:1) 草稿 token 的打分和验证;2) 生成 token 的输出;3) 新草稿 token 的生成。类似于图 4,我就不画了。

结语

所以个人思考,DeepSeek-V3 MTP 的最优实现方式,很大可能是将 1 层与主模型融合在一起调度,而不是按照独立模型单独执行;在 PD 分离时由 Prefill 节点同时负责 MTP 的 prefill。

[1] EAGLE and EAGLE-2: Lossless Inference Acceleration for LLMs - Hongyang Zhang, https://www.youtube.com/watch?v=oXRSorx-Llg

[2] EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty, https://arxiv.org/abs/2401.15077

[3] DeepSeek-V3 Technical Report, https://arxiv.org/abs/2412.19437v1

[4] LLM PD 分离背后的架构问题, https://yangwenbo.com/articles/reflections-on-prefilling-decoding-disaggregation-architecture.html

[5] https://github.com/vllm-project/vllm/pull/12755

DeepSeek 官方修正了 V3 的激活参数量说明

在之前的博客《DeepSeek V3 模型各子模块参数量精算》中,我计算的模型激活参数量跟官方 README_WEIGHT.md 中的说明对不上。之后有读者跟我说,官方更新了激活参数量的数字。我查了一下 commit history,具体修改如下:

DeepSeek V3 README_WEIGHTS.md commit

可以看到,V3 模型激活参数量从 36.7 改成了 36.6,并且去掉了包含 0.9B Embedding 的说明,那基本上跟我的计算完全对上了。MTP 激活参数量从 2.4B 改成了 1.5B,也去掉了 0.9B 的 Embedding,跟我的计算还是有 0.1B 的差异。

Anyway,这种总量统计只是为了揭示计算的大约规模,有点差异也不影响定性结论。真正有用的是你在拆分 TP、EP 等权重矩阵时,矩阵的形状是多大,要拆多少份,每份大概多大。

为了分析像 DeepSeek V3 这样的超大模型具体参数,我写了一个小脚本,可以将 safetensors 文件里面的权重 Shape 提取出来,并且可以按不同的层级做参数量的聚合计算:

https://github.com/solrex/solrex/blob/master/snippets/show_safetensors.py

#!/usr/bin/env python3
import os
import argparse
import torch

from safetensors import safe_open

def print_tensor_tsv(model_dir, depth):
    '''Print tensor info in .safetensors into tsv format'''
    TENSOR_CLASS = {
        'weight': 'weight',
        'e_score_correction_bias': 'weight',
        'weight_scale_inv': 'scale'
    }
    print('SafetensorsFile\tTensorKey\tTensorParams\tTensorType\tTensorShape')
    safetensor_files = sorted([f for f in os.listdir(model_dir) if f.endswith('.safetensors')])
    summary = {}
    for filename in safetensor_files:
        file_path = os.path.join(model_dir, filename)
        with safe_open(file_path, framework='pt') as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                print(f'{filename}\t{key}\t{tensor.numel()}\t{tensor.dtype}\t{tensor.shape}')
                lst = key.split('.')
                # Get suffix: .weight or .weight_scale_inv
                tclass = TENSOR_CLASS[lst[-1]]
                # Limit prefix to dep
                dep = min(len(lst), depth+1) if depth > 0 else len(lst)
                # Get summary of prefixes
                for prefix in ['.'.join(lst[:i]) for i in range(0, dep)]:
                    summary[f'{tclass}[{prefix}]'] = summary.get(f'{tclass}[{prefix}]', 0) + tensor.numel()
    for key in sorted(summary):
        print(f'Summary\t{key}\t{summary[key]}\t\t')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Print tensor shape and dtype of .safetensors file')
    parser.add_argument('model_dir', nargs='?', default='.', help='Model directory (default: $PWD)')
    parser.add_argument('--summary_depth', '-d', type=int, default=3, help='Summary depth of weights')
    args = parser.parse_args()
    print_tensor_tsv(args.model_dir, args.summary_depth)

在 HuggingFace 模型根目录下执行 ./show_safetensors.py ,即可获得当前模型的所有权重 Shape 和前 3 层的聚合权重规模。可以通过 “-d” 参数调整最大聚合的层级。输出的文件是 tsv 格式的,可以导入表格进行再计算。

以下是使用 show_safetensors.py 分析 DeepSeek V3 参数量的示例:

$ ./show_safetensors.py -d 2
SafetensorsFile TensorKey TensorParams TensorType TensorShape
model-00001-of-000163.safetensors model.embed_tokens.weight 926679040 torch.bfloat16 torch.Size([129280, 7168])
model-00001-of-000163.safetensors model.layers.0.input_layernorm.weight 7168 torch.bfloat16 torch.Size([7168])
...
model-00163-of-000163.safetensors model.layers.61.shared_head.head.weight 926679040 torch.bfloat16 torch.Size([129280, 7168])
model-00163-of-000163.safetensors model.layers.61.shared_head.norm.weight 7168 torch.bfloat16 torch.Size([7168])
Summary scale[] 41540496
Summary scale[model.layers] 41540496
Summary scale[model] 41540496
Summary weight[] 684489845504
Summary weight[lm_head] 926679040
Summary weight[model.embed_tokens] 926679040
Summary weight[model.layers] 682636480256
Summary weight[model.norm] 7168
Summary weight[model] 683563166464

可以看到第一列为文件名(像 model-00001-of-000163.safetensors)的行是该文件中的具体权重信息,包含 Shape 信息;第一列为 Summary 的行,是根据模型的 tensor key 名字结构, 例如 “model.layers.0.input_layernorm.weight”,按照 “.” 切成前缀,按照前缀聚合模型参数量的结果,不包含 Shape 信息。

LLM PD 分离背后的架构问题

PD 分离(Prefilling Decoding Disaggregation)推理是指将大模型推理的预填充阶段(P)和解码(D)阶段分离,以减少预填充与解码相互之间的影响,以便对两个阶段分别进行优化,提升 GPU 硬件的利用率,并减少推理延迟的一种推理技术方案。

在 DistServe、Mooncake 等论文中介绍分离式架构之后,DeepSeek V3 的报告让大家更进一步意识到 PD 分离可能是影响成本和性能的关键技术。

vLLM 对 PD 分离已经有了一个 1P1D 的实验版本。除此之外的开源框架大多还都不支持,不过很多已经在计划、实现中了。但纵览这些实现、文章或者计划,可以看到 PD 分离的架构选型上有很多问题需要思考,我尝试列举一下:

一、PD 是否直连传输?或是否需要 KV Cache Store/Pool?

PD 直连就是预填充节点直接将 KV Cache 发送给解码节点,它的好处是延迟低。但也意味着在整个 batch 的计算过程中锁定了P、D 节点的对应关系,一旦解码节点出现了问题,比如压力过大、服务出错、传输阻塞,在重试时无法仅调度 D 节点,需要重新进行整个预填充、解码过程。在 prompt 较长时,或者在 PD 节点数不对等的场景下,例如 2 个 P 对应到 1 个 D,重调度意味着抛弃较长或者多个 prefill batch,重调度的沉没成本较高。

使用 KV Cache Store/Pool 是在 P 和 D 之间增加了一个中间存储,预填充节点先将 KV Cache 写到中间存储,解码节点从中间存储读。这样做数据会多传输一次,增加了延迟,也增加了一些复杂度。但好处是容错性更好,还有就是预填充阶段本身也可以利用这个中间存储做 Prefix Caching。

中间存储也会对其它一些架构变动的复杂度产生影响,参见下面问题 四 和 五。

目前来看,Kimi Mooncacke、vLLM 的下一步设计、阿里 RTP-LLM 都使用或者计划使用基于 KV Cache Store/Pool 的方案,DeepSeek V3 报告中没有提到这部分。

在一些计算配比均衡、故障风险较小的场景下,比如同机多卡之间的 PD 分离,PD 直连的方案也有其简单易部署的优势。

二、P/D 是否按层发送/接收 KV Cache?

预填充最简单的实现是预填充节点完成第一个 token 的生成后,将所有的 KV Cache 传输给解码节点,这也是 vLLM 当前的实现。但这样实现有个问题,因为 KV Cache 的规模有可能非常大(尤其是原始 MHA),一个 batch 的 KV Cache 可能会是 GB 级别,都放在计算完成后传输,传输的延迟开销会比较大。

Kimi Mooncacke 和阿里 RTP-LLM 都采取了按层传输的方案,这是利用了 LLM 多层计算的自然特性。在完成一层的计算以后,就将这一层的 KV Cache 发送出去。这样 KV Cache 的发送就呈流式,既能降低延迟,也能使数据的发送更平滑。还存在一个更显著的优势,是 KV Cache 占用显存的时间更短,在显存紧张的情况下显存效率更高。

但按层发送对推理引擎的修改显然更大。我还没有看到开源的实现,猜测按层发送的引入对推理引擎的优化应该会有一定的影响,这里可能还需要一些精巧的设计才能减少影响。另外,按层发送对于 PD 非直连的场景下,中间存储的实现也会显著更复杂,QPS * num_hidden_layers,考虑到连续性可能还需要存储预分配和 session 保持。

因此对于 MLA 这种 KV Cache 偏小的注意力实现,比如 DeepSeek V3 的 KV Cache 是 576B/token/layer,是否要做按层发送,也许要看一下实际收益。

解码阶段和预填充阶段有所不同。解码需要多次迭代,在第一次迭代实现按层解码也没太大意义,而且涉及到计算的编排,应该需要拿到所有层的 KV Cache 才会开始计算。而且解码的计算时间比较长,如果解码的计算能够掩盖接收的延迟,不一定非要实现按层接收。

解码时按层接收,对调度也有一定挑战。从时序上来说,先发请求给预填充,完成后再发请求给解码会更自然。同时请求预填充和解码,需要处理一些同步问题,比如预填充压力大、解码等 KV Cache 超时等等。比如像阿里 RTP-LLM,它会观测预填充的排队情况,当一个请求进入预填充执行阶段时,解码端开始启动显存申请。

三、First Token 怎么处理

通常来说,预填充的同时会顺便把第一个 Token 计算出来,但计算到 hidden states 还是 token id 需要做一个选择。

计算到 hidden states 的好处是,预填充节点完全不需要加载和计算 lm_head 参数。比如 DeepSeek V3 的 lm_head 参数量是 0.9B,如果计算到 hidden states,这部分参数就完全不需要加载了。vLLM 目前就是采取的这个方式,预填充除了需要发送 KV Cache 之外,还需要发送一个 hidden states,解码时引擎也需要能支持加载 hidden states 延续计算。

计算到 token id 的好处是,发送的数据量小。以 DeepSeek V3 为例,hidden states 7K,token id 4B,完全可以跟着控制面消息传输。解码时引擎处理也更简单,因为 token id 到 token 的 detokenizer 一般是 CPU 查表,不涉及 tensor 的特殊处理。阿里 RTP-LLM 看起来采用的是这个方案。

四、Prefiller 和 Decoder 是否能相互转换?

当到达请求的 prompt 长度有差异性的时候,预填充和解码就会出现压力的不均衡问题。因为整体的吞吐取决于 P 和 D 的全局资源利用,当 P 过载但 D 闲置,或者 P 闲置但 D 过载的时候,成本和性能都不是最优的。

所以就需要考虑在 P 和 D 之间做负载均衡,要么从整个节点层面直接切换 P 和 D 的角色,要么 P 和 D 节点能够承担一些混杂的请求,比如通过 chunked prefill。

这时候 P 和 D 是否直连对实现复杂度就有一些影响了,如果有中间存储的存在,通过 PD 转换做负载均衡的实现难度会降低很多。

五、Decoder 能填充 KV Cache 吗?

如果业务应用场景中会将生成的 context 也作为下一轮的输入,还可能需要考虑 Decoder 填充 KV Cache,用于下一轮的 prefix caching 复用。这时候,KV Cache Store/Pool 的存在,对流畅交互有比较大的意义。

六、KV Cache Store/Pool 的设计抉择

有别于我们通常的 KV 存储,由于 GPU、RDMA(IB、RoCE)、NVLink 新硬件的存在,KV Cache Store/Pool 的设计抉择点会非常多。

在存储上,有 VRAM、DRAM、NVMe SSD,要选择 KV Cache Store 使用哪些介质。虽然对于 MHA 来说,因为 KV Cache 太大,基于 SSD 存储并不现实,但是对于 MQA、MLA 来说,NVMe SSD 并不是不可用。

在通信上,有 TCP、NVLink、RDMA、GPU Direct RDMA、NVMe over RDMA。为了更高的性能,KV Cache Store 在数据面上可能要考虑使用更快、更直接的传输方法。但 RDMA 对数据访问的抽象比 TCP 复杂很多,TCP 就是一端发一端收,但 RDMA 很多是单边操作。比如数据从 A 机 VRAM 发送到 B 机 DRAM,可能有以下方法:

  • A 从 VRAM 复制到 DRAM 再写 B 的 DRAM
  • A 从 VRAM 复制到 DRAM 再让 B 读 A 的 DRAM
  • A 直接从 VRAM 复制到 B 的 DRAM
  • B 直接读 A 的 VRAM

如果再加上 NVMe over RDMA,那要考虑的东西就更多了。P 发送到 Store,D 从 Store 接收,到底要通过哪些模式支持,是需要思考的。目前来看,预填充节点更适合单边写到 Store,这样能减少状态传输,更快地释放显存,但如果预填充节点也要读 prefix cache,那情况可能反过来;解码节点可能更适合单边读 Store。

在分布式架构上,无论是做集群式的 KV Cache Store,还是单机 side-car 式的 KV Cache Store,都需要存储一些 meta,并且在 P、D 之间传输一些控制信息。学术界有一些完全基于 RDMA 实现的分布式 KV 数据库,但目前看复杂度还是比较高,也没有开源的实现。目前业界实现还是倾向于使用传统的 RPC 方式来传输控制信息,并且通过分布式技术方案做 meta 节点的一致性、可靠性设计。

在接口 API 上,KV Cache Store 比传统的 KV Store 要复杂一些。比如要支持写的时候分 layer 写,读的时候能读到连续的内容;还可能要支持队列式的读,写完的 layer 可以很快被读走。如果要支持 prefix caching,还存在 KV Cache 的链式关系,写的时候不仅要分 layer,还要分 page,读的时候也是。TP/SP 等并行计算机制,对 API 可能还会有一些额外的要求。

在数据结构上,如果希望从 VRAM 直接写 Store,减少一次复制,引擎本身的 KV Cache 数据结构就需要与 Store 的数据结构进行一定程度的对齐;如果希望同时兼做 prefix caching,那 store 的数据排布就要考虑相同 prefix 的 page 更接近,甚至共享。比如用 prompt 的所有 page 的 hash 组成 string,按前缀 range 分桶,桶内对相同前缀做 merge/引用等等,这在存储优化上会是一个挑战。

整体来看,PD 分离的实现上有很多架构问题需要抉择,目前还没有一个理想的架构方案,或许未来也会是根据不同场景有很多参数化的灵活配置。

DeepSeek V3 模型各子模块参数量精算

网上很多文章一般只提到 DeepSeek V3 模型的总参数量,很少有人分析各子模块的参数量。我试着让各 AI 根据配置计算一下,没有一个靠谱的,只能自己算了。(也许本文的内容后续会变成 AI 回答本问题的 RAG 养料)

下面是根据 DeepSeek V3 开源仓库 https://huggingface.co/deepseek-ai/DeepSeek-V3,对 DeepSeek V3 各子模块参数量进行的精算,在计算复杂的 TP、DP、EP 拆分时可以用作基数参考。如有错误,烦请评论指出。

嵌入层 Embedding

"vocab_size": 129280, // Token 字典大小
"hidden_size": 7168,

DeepSeek V3 的嵌入层参数量是:

129280 * 7168 = 926,679,040 (~0.9B)

MLA

"hidden_size": 7168,
"num_key_value_heads": 128,
"v_head_dim": 128,
"kv_lora_rank": 512,

"num_attention_heads": 128,
"q_lora_rank": 1536,

"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,

"num_hidden_layers": 61,

单层 MLA 中 Q 的 LoRA 参数量是:

7168 * 1536 + 1536 + 1536 * 128 * (128 + 64) = 48,760,320

单层 MLA 中 KV 的 LoRA 参数量是:

7168 * (512 + 64) + 512 + 512 * 128 * (128 + 128) = 20,906,496

单层 MLA 中 WO 的参数量是

128 * 128 * 7168 = 117,440,512

pre 和 post attention layernorm 的参数量是:

7168 * 2 = 14336

所以 DeepSeek V3 的 MLA 部分共 61 层的总参数量是:

(48,760,320 + 20,906,496 + 117,440,512 + 14336) * 61 = 11,414,421,504 (~11B)

MoE

"num_hidden_layers": 61,
"hidden_size": 7168,
"moe_intermediate_size": 2048, // 路由专家 MLP 的中间维度
"n_shared_experts": 1, // 共享专家数量
"n_routed_experts": 256, // 路由专家数量
"first_k_dense_replace": 3, // 前几层使用dense替换MoE
"intermediate_size": 18432, // 前3层 (9*moe_intermediate_size)

每个专家的参数量是:

7168 * 2048 * 3 = 44,040,192

路由 Gate 的参数量是:

256 * 7168 + 256 = 1,835,264

前 3 层 dense(固定激活 8 路由专家),前 3 层参数量是:

44,040,192 * 9 * 3 = 1,189,085,184

后 58 层稀疏(动态激活 8 路由专家),后 58 层参数量是:

(44,040,192 * 257 + 1,835,264) * 58 = 656,569,547,264

所以 DeepSeek V3 的 MoE 部分的总参数量是:

1,189,085,184 + 656,569,547,264 = 657,758,632,448 (~657B)

每次计算激活 1 个共享专家,8 个路由专家,所以 DeepSeek V3 MoE 部分的激活参数量是:

44,040,192 * 9 * 61 + 1,835,264 * 58 = 24,284,510,720 (~24B)

Layer 维度

前 3 层是 dense,没有 gate,基于上面的计算,DeepSeek V3 前 3 层每层的参数量是:

(48,760,320 + 20,906,496 + 117,440,512 + 14336) + (44,040,192 * 9) = 583,483,392

后 58 层是 MoE 稀疏激活专家,基于上面的计算,DeepSeek V3 后 58 层每层的参数量是:

(48,760,320 + 20,906,496 + 117,440,512 + 14336) + (44,040,192 * 257 + 1,835,264) = 11,507,286,272

输出层

DeepSeek V3 输出层的 RMSNorm 和 Linear 参数量是:

7168 和 129280 * 7168 = 926,686,208 (~0.9B)

总参数量

核对一下 DeepSeek V3 总参数量是否为 671B:

583,483,392 * 3 + 11,507,286,272 * 58 + 926,679,040 * 2 + 7168 = 671,026,419,200 (~671B)

核对一下 DeepSeek V3 激活参数量是否为 37B:

11,414,421,504 + 24,284,510,720 + 926,679,040 * 2 + 7168 = 37,552,297,472 (~37B)

这个与 README_WEIGHT.md 中提到的 36.7B 不同,我还没找到计算错误的地方。我理解也许是考虑到 embedding 层只是查表,并不是矩阵乘,所以实际激活参数是:

11,414,421,504 + 24,284,510,720 + 926,679,040 + 7168 + 7168 = 36,625,625,600 (~36.6B)

PS (20250208)

DeepSeek 更新了 README_WEIGHT.md ,激活参数量修正成了 36.6B,也去掉了包含 0.9B 输入 embedding 的注释。

MTP

DeepSeek V3 MTP 的 ebedding 和输出 head 与主模型共享,enorm 和 hnorm 的权重是:

7168 + 7168 = 14336

eh_proj 线性变换的权重规模是:

7168 * 14336 = 102,760,448

增加了一层 hidden layer,即第 61 层:

(48,760,320 + 20,906,496 + 117,440,512 + 14336) + (44,040,192 * 257 + 1,835,264) = 11,507,286,272

加起来 DeepSeek V3 MTP 的总参数量是:

11,507,286,272 + 102,760,448 + 14336 = 11,610,061,056 (~11.6B)

DeepSeek V3 MTP 的激活参数量是:

11,610,061,056 - 44,040,192 * (256 - 8) + 926,686,208 * 2 = 2,541,465,856

这个规模比 README_WEIGHT.md 中提到的 11.5B 独立参数,和 2.4B 激活参数都略大一点。

PS (20250208)

DeekSeek 更新了 README_WEIGHT.md ,MTP 的激活参数量由 2.4B 改成了 1.5B,可能跟上面的激活参数一样,都减去了 embedding 层。但在我的计算里,这个应该是 1.6B :),还是略有不同。

11,610,061,056 - 44,040,192 * (256 - 8) + 926,686,208 = 1,614,779,648 (~1.6B)

DeepSeek V3:AI 大模型 infra 基建新高度

AI 工程化

2021 年初,贾扬清在阿里云开始推广 AI 工程化这个概念,我非常认同。我在 21 年中做技术规划的时候,提出“AI 到生产力的转化,需要更高的工程化能力”,并且将 AI 工程化的实施总结为几大方向:

  • 语义索引场景化
  • 算力调度混合化
  • 模型研发标准化
  • 优化技术普惠化
  • 模型超大规模化
  • 架构系统智能化

我的 AI 工程化团队在这些方向上也取得了许多成果。

The AI Model

但 2022 年底 LLM 大流行以后,情况发生了一些变化。原因主要是 LLM 让 AI models 变成了 The AI model,虽然这个 model 很大,也多多少少有一些变种,但从工程实践的角度来看,它并不“复杂”。

很多时候,工程架构解决的是复杂性问题。

比如,TensorFlow、PyTorch、PaddlePaddle 这些训练框架简化了搭建和训练神经网络的复杂度,在一段时间内,各种结构的网络层出不穷,大部分都是依托这些框架来实现的。

而对于 LLM 来说,模型结构相对固定,虽然也使用了框架的一些外围能力,但是模型结构核心部分已经逐渐变成全手写以达成最佳性能,典型的实现包括 FlashAttention、TRT-LLM 等。

而且 LLM 的接口调用是自然语言,因而也变得极其简单,所有的 LLM 模型几乎可以使用同一套 API。

当时看起来 LLM 不需要太多的架构基建工作。

Prefix Caching 和 Kimi

我的这个认知在思考 prefix-caching 作用的时候,有了一些改变。

在《应该把 Prefix Caching 当作一种效果优化技术》这篇博客中,我提到 Prefix Cache Aware Scheduling 是一件非常值得做的事情。而且从 Kimi 发表的论文来看,他们已经在实践了,但其它的技术报告提到这些工程架构工作的不多。

DeepSeek V3

前几天 DeepSeek AI 发布了 DeepSeek V3 版本,我一边在吐槽这 670B 的模型参数太大,下载太慢,一边在阅读它的技术报告。结果发现他们在模型的部署上,玩得更高端,给了我一些新的震撼。

首先,prefilling 和 decoding 分开部署。prefilling 4 机 32 卡,decoding 40 机 320 卡。这样一来,我之前《LLM 推理优化 Continuous Batching 及其实现》这篇博客中提到的 Continuous Batching 就不再需要了。两阶段分开后,prefill 的计算过程(长度)是确定的,其算力利用是充分的,不再需要中间停下来插入新的请求。其实 prefilling 能够分开部署,跟 DeepSeek 以前的研究也是分不开的,DeepSeek V2 引入的 MLA 对 KV Cache 做了大幅度的低秩压缩,可以显著降低 KV Cache 从 prefilling 节点传递到 decoding 节点的带宽和延迟。

其次,MoE 专家分开部署。因为 MoE 专家的激活是 Token 级别的,也就是说每个 Token 会决定走哪个专家进行计算,分开部署就可能会带来负载均衡问题:有些专家太忙,有些专家太闲。DeepSeek V3 为了解决这个问题,还做了复杂的负载均衡策略。例如:快速识别较忙的专家,部署冗余的专家副本以承担压力;重新调整专家在不同节点的布局,尽量利用跨 GPU 带宽而减少跨节点带宽(因为 IB 比 NVLink 要慢);单卡冗余部署多专家,但通过全局的路由计算来优化专家的动态激活数量。

DeepSeek V3 还实现了计算和通信重叠。为了掩盖分布式计算过程中进行集合通信时的开销,将计算任务分为微批。一个微批进行集合通信时,进行下一个微批的计算。

此外,DeepSeek V3 在推理时还将 TP(Tensor)、DP(Data)、SP(Sequence)、EP(Expert)不同维度的并行化融合到了一起。单拿出来一种并行化方法也许现在的框架还是支持的,但这些方法组合在一起,我怀疑目前也没有什么推理加速框架能直接处理。

从技术报告中揭露的这些细节可以看出,为了发挥出模型的极致性能,DeepSeek 在 AI 大模型的分布式部署上花费了很大的心思。这也让 DeepSeek V3 成为目前公开资料可以看到的最复杂、最精巧的大模型 infra 设计

这些 idea 以前也许不是没有人想到,但是 infra 的演进是有很高研发和试错成本的。当 DeepSeek 将这些路走通以后,也许未来的很多大模型公司,大模型框架,都会往沿着这个方向继续演进。

应该把 Prefix Caching 当作一种效果优化技术

我在 4 月份写的博客《LLM 推理优化 Prefix Caching 及其实现》带来了不少流量,这篇博客在 Google 搜索“Prefix Caching”时排在比较靠前的位置,也给我的微信公众号“边际效应”带来了超过 100 个关注者,由此可以看到大家对这项技术的广泛兴趣。

新认知

最近关于 Prefix Caching 我又产生了一次认知升级:在一个 Eureka 时刻,我忽然领悟到 Prefix Caching 实际上可以是一种效果优化手段——但让其充分发挥效能需要跟长上下文语言模型 Long-Context LLM 和基于 Prefix Cache 的调度技术相结合

RAG v.s. Super Long Domain Specific Prefixes

为了减少模型的幻觉,提升时效性,我们往往会采取 RAG 技术来增强模型。召回的结果本身会增加 Prompt 的长度,会加大 Context 的计算量,会影响模型的推理时长,因而召回结果的补充也是谨慎克制的,不会让 Prompt 变得非常长。

但是 Prefix Caching 优化技术给我们开了个口子:如果把信息放在 Prompt 的共享 Prefix 中,加长 Prompt 的代价就由计算和时延,转化到了存储。而存储代价可以通过 Cache 复用来摊薄,这就是一笔很划算的经济账。

如果把 RAG 召回的结果当作 Prompt 里的变量,那么 Prefix 就是 Prompt 里的常量,当增加变量信息的规模受到限制时,增加常量信息的规模也可以提升生成的效果。

举个例子:如果你做 K12 领域的模型,把 12 个年级的语文知识点都放在 Prefix 里,然后再叠加 RAG,然后回答用户语文方面的提问,肯定能更好地保证生成的效果。

所以,LLM 技术后续的一个新发展分支,也许会是超长特定领域专用的前缀设计

Inference Time Scaling

最近讨论很多的一个技术方向,是从基于预训练提升效果的 Scaling Law 转向基于 reasoning at inference time 的 Scaling Law,使用更多的推理计算来获得更好的效果。

但这些 Scaling Law 还是基于“computing”的 Scaling,我认为也许会存在基于“memory”的 Scaling Law,即更长的共享 Domain Specific Prefix 带来更好的效果。因为这里的“memory”是计算的缓存,memory 的共享本质上是计算的共享。

Long Context Large Language Model

在无法共享计算的场景下,长上下文的大语言模型应用往往由于其计算成本或显著延迟而无用武之地。但如果基于 memory 的 Scaling Law work 的话,那长上下文大语言模型将成为必经之路

这也许是 Moonshot 和 Gemini 早早投入百万 Token 级别长上下文模型的背后逻辑

Prefix Cache Aware Scheduling

在没有理想的大一统 Prefix 之前(大概率也不可能),共享 Prefix 只能是 Domain Specific 才最合适。那就意味着会有很多个共享 Prefix,而 Prefix Cache 占用的空间又是不可忽视的,所以不可能在每张卡/每台机器上都存储所有的共享 Prefix。

这就要求在用户请求调度时,根据 Prefix 分配到不同的卡/机器/集群上。Prefix Cache 也需要能够根据请求的热度在不同的服务器间弹性扩散或者收缩,甚至在显存、内存、SSD 之间调度。其实 Kimi 的论文“Mooncake: A KVCache-centric Disaggregated Architecture for LLM Serving”就是在解决这个问题,只是我之前没有意识到这个技术的广泛重要性

估值最高的 AI 搜索独角兽 Perplexity 使用倒排索引做 RAG

3 月份我曾写了一篇博客《在 LLM 时代我们是否还需要倒排索引? 》,探讨了在向量数据库崛起的情况下倒排索引仍然存在的价值。

也许仍然有人对传统搜索技术弃如敝履,但前几天我正好看到一个 Lex Fridman 对 Perplexity CEO Aravind Srinivas 的采访,他表示 Perplexity 仍然使用(可能不是仅使用)倒排索引和 BM25 传统算法进行全网搜索。作为估值 30 亿美元的 AI 搜索独角兽公司,在全网搜索上却也选择了倒排索引,而且访谈中非常强调倒排的重要性,这应该算是印证了我的观点吧。

他的这些思考对我也有一些启发,为了保持互联网的记忆,我将这些相关节选的原文和翻译记录如下:

Aravind Srinivas 2024年6月20日采访节选

这个 Podcast 有 3 个小时,可以从 2:03:52 开始看。

Lex Fridman (02:03:52)

这是一个通过(将内容)嵌入到某种向量空间实现的完全的机器学习系统吗?
Is that a fully machine learning system with embedding into some kind of vector space?

Aravind Srinivas (02:03:57)

它并不是纯粹的向量空间。并不是说内容一旦被获取,就会有某种 BERT 模型在所有内容上运行并将其放入一个巨大的向量数据库中进行检索。
It’s not purely vector space. It’s not like once the content is fetched, there is some BERT model that runs on all of it and puts it into a big, gigantic vector database which you retrieve from.

这并不是那样的,因为将网页的所有知识打包到一个向量空间表示中是非常困难的。首先,向量嵌入在处理文本时并没有想象中的那么神奇。理解一个文档是否与一个特定检索词相关非常困难。它应该是与检索词中的人物有关,还是与检索词中的特定事件有关,更或者它是基于对检索词更深层次的理解,使得检索词可以应用到另外一个人物,他也需要被检索到?什么才是(向量)表示应该捕捉的内容?大家可能会不断争论。而让这些向量嵌入具有不同维度,彼此解耦并捕捉不同语义是非常困难的。
It’s not like that, because packing all the knowledge about a webpage into one vector space representation is very, very difficult. First of all, vector embeddings are not magically working for text. It’s very hard to understand what’s a relevant document to a particular query. Should it be about the individual in the query or should it be about the specific event in the query or should it be at a deeper level about the meaning of that query, such that the same meaning applying to a different individual should also be retrieved? You can keep arguing. What should a representation really capture? And it’s very hard to make these vector embeddings have different dimensions, be disentangled from each other, and capturing different semantics.

顺便说一句,这只是排序部分的内容。还有索引部分——假设你有一些处理后的 URL,排序模块会基于你提问的检索词,从索引中召回相关的文档和一些打分。这就是为什么,当你的索引中有数十亿个页面,而你只需要前 K 个结果时,你必须依赖近似算法来获取这些前 K 个结果。
This is the ranking part, by the way. There’s the indexing part, assuming you have a post-process version for URL, and then there’s a ranking part that, depending on the query you ask, fetches the relevant documents from the index and some kind of score. And that’s where, when you have billions of pages in your index and you only want the top K, you have to rely on approximate algorithms to get you the top K.

Lex Fridman (02:05:25)

所以这就是排序,但是将页面转换为可以存储在向量数据库中的内容,这一步似乎非常困难。
So that’s the ranking, but that step of converting a page into something that could be stored in a vector database, it just seems really difficult.

Aravind Srinivas (02:05:38)

它并不必须全存在向量数据库中。你可以使用其他数据结构和其他形式的传统检索。有一种算法叫做 BM25,它正是为此设计的,是 TF-IDF 的一个更复杂的版本。TF-IDF 是词频乘以逆文档频率,是一种非常老派的信息检索系统,实际上即使在今天也仍然非常有效。BM25 是它的一个更复杂的版本,仍然在许多排名中击败大多数嵌入。当 OpenAI 发布他们的嵌入时,围绕它有一些争议,因为在许多检索基准测试中,它甚至没有击败 BM25,这并不是因为他们做得不好,而是因为 BM25 太好了。这就是为什么仅靠纯粹的嵌入和向量空间并不能解决搜索问题。你需要传统的基于词项的检索,你需要某种基于 NGram 的检索。
It doesn’t always have to be stored entirely in vector databases. There are other data structures you can use and other forms of traditional retrieval that you can use. There is an algorithm called BM25 precisely for this, which is a more sophisticated version of TF-IDF. TF-IDF is term frequency times inverse document frequency, a very old-school information retrieval system that just works actually really well even today. And BM25 is a more sophisticated version of that, that is still beating most embeddings on ranking. When OpenAI released their embeddings, there was some controversy around it because it wasn’t even beating BM25 on many retrieval benchmarks, not because they didn’t do a good job. BM25 is so good. So this is why just pure embeddings and vector spaces are not going to solve the search problem. You need the traditional term-based retrieval. You need some kind of NGram-based retrieval.

Aravind Srinivas 2023年12月14日采访节选

其实在去年 Unsupervised Learning 的 Jacob Ephron 采访 Aravind Srinivas 时,他也有过类似的表述,但是没有像最新一次采访那样强调倒排的重要性。这个 Podcast 也很长,可以从 28:08 开始看。

Aravind Srinivas (28:08)

很多人认为,既然我们在网页搜索的 RAG 方面如此擅长,那么 Perplexity 就能轻松搞定内部搜索。不,这是完全不同的两回事。
A lot of people think that because we are so good at RAG for web search, Perplexity will just nail internal search. No, it's two completely different entities.

Google Drive 的搜索之所以糟糕是有原因的。你可能会想,Google作为网页搜索的王者,怎么会这么差劲?他们之所以差劲,是因为索引机制完全不同,需要训练的嵌入模型也完全不同。这不仅仅是嵌入模型的问题,还包括你如何对页面进行摘要、如何进行文本召回——如何使用传统的基于TF-IDF的倒排索引来构建弹性索引,这些都大不相同。
There is a reason why search on Google Drive sucks. Like would you expect Google the king of web search to be so bad? They're bad because of a reason that the indexing is so different, the embeddings that you got to train are so different. Not just the embeddings, but even the way you snippet a page, your text retrieval——the elastic index that you're building with traditional TF-IDF based inverted indexes are so different.

所以需要某家公司专注于这种场景,就像我们专注于网页搜索场景一样。RAG 是一项非常艰巨的任务,在生成式 AI 之外还有很多工作需要完成。这不仅仅是训练一个大型的嵌入模型就能解决的问题。
That you need a company to just obsessively focus on that use case. Just like how we are obsessively focused on the web search use case. So RAG is going to be pretty hard and there's a lot of work that needs to be done outside of generative AI. It's not just training a large embedding model and you're done.

我记得当 OpenAI 发布嵌入 API 时,Sam Altman(OpenAI的CEO)在推特上说,下一个万亿公司可能只需要接入这个 API 就可以建立起来。这虽然是一种很好的营销方式,但事实并非如此。他当时说的是一万亿美元。所以,当听到有人声称他们已经解决了 RAG 问题时,我会非常谨慎。他们可能只是在某个场景下做得很好。
I remember like when OpenAI releases embedding API, Sam Altman tweeted the next 1 trillion company can just plug into this API and be built. That's not true. It's a good way to market it. But that's not true at all. Sorry, he said 1 trillion dollars. So I think that's why I would be very careful when somebody makes claims that they've solve RAG. They probably can do it really well for one use case.

此外,在排名中还有很多其他因素需要考虑,才能使答案真正出色。即使 LLM 最终决定了哪些链接用于答案,你也不能仅仅将一些垃圾信息放入 prompt 中,就指望它能神奇地只选择最相关的链接,并在答案中给出引用。
You know and also there are so many more things to handle in the ranking. That'll make the answer really good. Cuz even though the LLMs are finally the ones that are taking which links to use for the answer, it's not like you can just dump garbage into the prompt and it'll just be magically so good that it'll only take the top most relevant links in the answer and give you the citations with them.

事实上,你向这些长上下文模型提供的信息越多,最终出现幻觉的可能性就越大。因此,你实际上需要在检索模块上投入大量工作,不仅仅是嵌入向量,在索引、嵌入向量和排序上都要投入。排序也应该包含很多除了向量内积之外的其他信号,具体是哪些信号事实上取决于你的场景。
In fact the more information you throw at these really long context models, the more chances that you have a hallucination at the end. So you actually have to do a lot of work in the retrieval component, not just the embeddings, the indexing, the embeddings and the ranking. Ranking should also have a lot of signals outside of just the vector dot products. And then what is those signals are really depend on your use case.

分析(在2024年8月)

从以上不同时期 Aravind 的公开披露的信息分析,几乎可以说 Perplexity 在当前时间点,在召回阶段主要(如果不是全部的话)依赖倒排索引,在排序阶段会用到嵌入向量和其它信号,并且他们很重视除了嵌入向量之外的其它信号。

LLM 推理优化 Prefix Caching 及其实现

Prefix Caching

上文中提到,Prompt 计算(Prefill 阶段)和生成阶段的计算特性很不相同。为避免重复计算,所有框架的 prefill 阶段主要作用就是给迭代的生成阶段准备 KV Cache。但这些 KV Cache 仅仅是为单次生成式请求服务的,那很自然的一种想法就是,KV Cache 能不能跨请求复用?

在某些场景下,多次请求的 Prompt 可能会共享同一个前缀(Prefix),比如拟人 Agent 的人物设定,文档阅读理解时的文档内容等。这些情况下,很多请求的前缀的 KV Cache 计算的结果是相同的,像一般互联网服务的请求缓存一样,可以被缓存起来,给下一个请求复用。

限制

但 KV Cache 跟其它服务缓存不一样的地方是,它太大了,以至于(目前)很难通过 Redis/Memcache 这种分布式缓存服务存取。比如对 13B LLM 模型来说,在 FP16 精度下单 token 的 KV Cache 大约是 1MB,假设要缓存的前缀有 500 个 token(大约800多个汉字),那就是 500MB。一般来说,我们不会每次请求去从分布式系统里读取/传输 500MB 的缓存,甚至都不会每次请求从内存往显存中拷贝 500MB 的缓存,所以大部分情况下,prefix cache 都会放在显存里。

这也就意味着,如果你想命中 prefix cache,必须把相同 prefix 的请求发到同一张 GPU卡上才行。

实现

由于不是普遍需求,加上前面说的限制,prefix caching 作为一个加速特性,不是很受关注,一般也不是默认开启的。各框架的实现和配置略有差异,这里简单做下记录,便于回顾。

刚开始 vLLM 的实现是给 generate 接口增加一个 prefix_pos 参数,通过 prefix_pos 输入参数为每个请求指定 prefix 长度,为 prefix 建一个带淘汰的哈希缓存。后来觉得这样做使用上不够便利,升级成了自动前缀缓存,即将 prompt 的 kv cache 分成 block,然后为 block 建设 LRU 缓存机制,这样就不必在接口上使用 prefix_pos 指定哪部分是 prefix 了。自动前缀缓存功能默认是不开启的,开启的配置项为 --enable-prefix-caching

TensorRT-LLM 与 vLLM 后来的实现类似,也是实现了 block kv cache,配置项是 enableBlockReuse,默认也是不开启的。代码未开源,无法看到实现。

Lmdeploy 的 PythonTurboMind C++ 版本的 prefix caching 功能都已经有了 PR,但现在(20240425)看还没有合入主干。有意思的是它没有使用 hash block 对应 token_id 子串的所有 token_id 前缀然后组成哈希表的方式,而是用 hash 当前 block 对应的 token_id 子串然后组成 trie 树的缓存管理结构。默认的参数名与 vLLM 相同,也叫做 --enable-prefix-caching。

HuggingFace TGI 现在看起来还没实现 Prefix Caching 功能。

Prompt Caching

除了 Prefix Caching 这种比较直观的工程优化,现在也有一些研究在看 Prompt 的其它缓存机制。比如设计一种机制让 prompt 模块化,不仅可以复用 prefix,还能复用中间的部分;或者通过 query 的相似性复用其它 query 的 prompt

但目前看实现上都过于复杂,比如第一种要求模型使用不连续的 poition_id,这样就可以插入 token,但这种方式对 attention 的计算机制有一定的影响,难以说明它对效果的影响。

LLM 推理优化 Continuous Batching 及其实现

原理

Continuous Batching 是 LLM 推理优化的一项技术,作为这篇文章的知识背景不再赘述,目前流传最广的参考资料是这篇:《How continuous batching enables 23x throughput in LLM inference while reducing p50 latency》。它也有中文翻译,感兴趣可以搜一下,先看看。

图片来自:Anyscale 博客

虽然这篇资料介绍了它的主要原理,尤其是这张简单易懂的图,但是实现与原理是存在差异的,因为工程实现要解决很多现实问题。

(Re)scheduling 开销

如原文所说,Continuous batching 还有个别名,叫做:batching with iteration-level scheduling,这里的 iteration 就是指一次 decode 计算。也就是说在每次 decode 的迭代过程中,做 batch 的调度调整。

但调度本身不是无代价的,它可能涉及到接收和处理新的输入请求,重新组织输入数据的形状,甚至各种状态的重新初始化,这些都需要消耗 CPU 时间。这也就意味着在这段时间里,GPU 是闲着的,GPU 没有得到充分利用。

所以在实现时,程序并不会真的每个 iteration 都做 scheduling,目前看到有两种做法:

  • 合理间隔调度。比如每 16 次 decode 计算后,检查一下是否有新的输入,以及是否有空闲的槽位,然后对 batch 做一次调度调整。这能够显著降低调度的开销(TGIlmdeployvLLM)。
  • 排队比例调度。比如当前 batch 中有 10 个请求的 decode 正在进行,而排队中有 12 个请求,超过了排队比例 1.2,那么就启动一次调度调整(TGI)。

KV Cache 重读

如果真的像图中那样,每个生成 Token 的 decode iteration 与一个 prompt token 的计算对齐,那 KV Cache 的利用就会很糟糕。因为它们需要在 Global Memory 与 Shared Memory 之间反复搬运,而且每次搬进来以后只被使用一次。

这本质上是 prefill 阶段(prompt 计算)与生成阶段的计算特性不同导致的,prefill 阶段并不适合 one-by-one 的 token 计算。

所以在实现时,程序并不会真的做 prefill 和生成的 token 对齐调度。目前看到的调度方法有三种:

  • 在重调度时,如果有新的请求进来,那么将新请求的 prefill 计算和要进行的 decode 计算做合并(Orca、vLLM-prev)。
  • 在重调度时,如果有新的请求进来,那么先对新的请求做 prefill 计算,然后再合并所有进行中的请求做 decode 计算(TGI、vLLM、lmdeploy)。
  • 先根据 decode 耗时估算出来每次 decode 同时能做 prefill 的 token 数量,在重调度时,如果有新的请求进来,对新请求的 prefill 按照上面的估算进行分块,然后将分块的 prefill 和其它请求的 decode 计算融合在一起,一定程度上减少了 KV Cache 重读的次数,又避免先做 prefill 计算带来的 Token 生成延时增加(Sarathi-Serve+vLLM)。

可调优能力

LLM 推理服务往往用于生产环境,而生产环境面临的情况是复杂多样的。

  • 对于做阅读理解的应用来说,Prompt 可能会非常长,但生成的内容可能会非常短,开发人员可能会更追求吞吐;
  • 对于聊天应用来说,Prompt 可能较短,生成的内容也不会太长,开发人员可能会更追求延迟;
  • 对于创作类应用来说,Prompt 可能很短,生成的内容会更长,开发人员可能会更追求首 Token 延迟。

对 Continuous Batching 实现来说,就要求它调度策略尽量清晰,并且参数可调。所以更灵活的实现,未来可能会更受欢迎。

Logits of API-Protected LLMs Leak Proprietary Information

看到一篇挺有意思的论文,大开脑洞,没想到还能这么玩,做一下粗读的笔记。

论文

标题:《Logits of API-Protected LLMs Leak Proprietary Information》,链接: https://arxiv.org/pdf/2403.09539.pdf

假设条件

典型 LLM 需要将最后一个 Transformer 块输出的嵌入向量转成要输出的 Token,这一步往往通过一个线性变换加 softmax 获取那个最大概率的 tokenid。

线性变换的权重是一个 vocabulary size * hidden size 的矩阵,比如 llama2-7B 的词表大小是 32000,hidden size 是 4096,那么线性变换权重矩阵的尺寸就是 32000x4096。这个矩阵再与 4096x1 的嵌入向量相乘,得到的就是 32000x1 的 logits 向量,其中每一个元素对应着一个词表中的 token 作为最终输出的概率。

上面这只是假设,也许 GPT 使用的是一个非线性变换,那论文内容可能就不成立了。

数学原理

这个线性变换将一个 4096 维的向量映射到了一个 32000 维的向量,从线性代数的角度来看,这是一个低维向高维的映射,所以它肯定不是一个满射(onto mapping)。也就是说,这个映射的像空间(image)只是 32000 维实数空间的一个子空间(subspace),而且这个像空间的秩(rank)最多是 4096。

这意味着可以找到不多于 4096 个线性无关的基向量(basis),使得这个像空间的每一个元素都能表示为这些基向量的线性组合。假设能采集到 4096 个线性无关的输出 logits,那这些 logits 就构成了像空间的一组基向量。

反过来想,如果你不知道 LLM 的 hidden size,那么你可以通过采集足够多的输出 logits,以保证有足够多的线性无关的向量。然后对矩阵进行奇异值分解(singular value decomposition),可以通过非 0 的奇异值个数推导出矩阵的秩。这个秩应该接近于模型的 hidden size。

逆向恢复 logits

遗憾的是,很多模型的 API 并没有输出完整的 logits 矩阵,但幸运的是,OpenAI 的 API 支持输出最多 top 5 个 token 的 logprobs,并且支持 logit_bias 干预 token 的输出。那就给了反复通过 API 调用来逆向恢复 logits 向量的可能。

但是具体方法我没看,粗读嘛,知道能做到就行了,有用到的时候再看吧。还有另一篇文章《Stealing Part of a Production Language Model》分析了在没有 logit_bias 甚至没有 logprobs 时该如何恢复 logits,我也没看,记录下链接 https://arxiv.org/pdf/2403.06634.pdf

无法输出的 Token

这篇论文还介绍了很多其它应用,太长没有看。比较有意思的一个引用是,在将嵌入向量映射到 logits 的过程中,如果一个 token 的嵌入向量在其它 token 的嵌入向量组成的凸包的内部,它就永远不可能被输出。扫了一眼引用的论文,证明没看懂,大致意思是 softmax 权重矩阵的低秩特性导致了可能输出 token 的排列在线性变换后不会出现在子空间里?实话说我感觉不像是很严谨的数学证明。。。

在 LLM 时代我们是否还需要倒排索引?

近些年,EBR(基于文本嵌入向量的召回)的强大语义召回能力让基于属性索引的传统倒排索引技术黯然失色,即使对专业搜索引擎来说,EBR 的应用也是越来越广泛 [1,2,3] 。尤其在 LLM(大语言模型)激起的 RAG(检索增强生成)技术框架下,大部分人似乎已经忘记了倒排索引,向量数据库成为 RAG 的标配。

但在 LLM 时代倒排索引真的没有用武之地了吗?我尝试列一下自己的思考。

  • Embedding 向量缺少 ground truth,但是倒排有。你无法通过直接观察或者测量,来明确一个向量指代的一段文本是否一定包含某些信息。但是如果这段文本在某个 term 的倒排拉链里,你可以从一定程度上明确这段文本包含了一些相关信息。
  • term 命中也是一种 attention。在训练模型时,我们总是希望 LLM 能关注到 context 中应该关注的信息,而忽略其它无关内容。那跟用户问题或者指令中某些 term 相关的内容,应该需要更多的关注。其实也可以类比人类肉眼的信息查找,人们也总是会扫到某些关键词,然后再仔细阅读它的上下文篇章。
  • 不基于倒排做召回,仍可以用倒排做粗筛。倒排作为一种可以查询 term 命中的高效结构,对 EBR 也许可以起到补充作用。例如对于某些 EBR 效果不够理想,误召回概率较高的场景下,对得分比较低的文档用命中信息作一次粗筛,能显著提升送给模型的 context 质量,也能减少对 LLM 计算资源的浪费。
  • term 命中的位置信息和权重不再重要。对于 LLM 来说,它会自行关注到 context 中需要关注的信息,不再需要位置信息或者权重来指示文本中哪些部分更重要。也就是说,倒排只需要解答 term 在文本中是否出现的问题,而不需要回答出现几次、在哪里出现的问题。
  • 也许倒排不再用 term,而是 token。term 依赖于切词,有一定的语义含义,term 集合空间一般有百万甚至千万的量级。但现在 LLM 大部分使用 BPE(Byte-Pair Encoding)分词器,token 集合空间只有几万到十几万的量级。使用 token 将显著减少倒排链的数量而优化其性能,但 token 存在没有归一化、分词边界不对齐的问题,是否可用还有待验证。

参考

[1] Guo, Ruiqi, et al. "Accelerating large-scale inference with anisotropic vector quantization." International Conference on Machine Learning. PMLR, 2020

[2] Huang, Jui-Ting, et al. "Embedding-based retrieval in facebook search." Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2020.

[3] Liu, Yiding, et al. "Pre-trained language model for web-scale retrieval in baidu search." Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining. 2021.