在 32 张 L40S/L20 上运行 DeepSeek-R1/V3 原版 FP8 模型

上文讲到,FP8 模型之所以无法 TP32 运行,主要因为 DeepSeek R1/V3 模型保存的参数是 FP8 128x128 量化的。Attention 还好,128 个头做 TP16 或者 TP32 都没问题,问题主要出在专家的计算上。

前三层 MLP 的 intermediate_size 是 18432,18432 做 TP16 是 1152,相当于 9 个 128。但如果做了 TP32,就是 4.5 个 128,这会导致参数无法切分在 128 边界上,无法支持按 128 block 量化。路由专家也类似,moe_intermediate_size 做 TP32 相当于 0.5 个 128。

缩小 DeepSeek-R1/V3 的量化块到 64x64

如果想支持 TP32,一个很显然的路径就是把量化方式改成按 64x64 分块量化。本来这是一个比较复杂的操作,但我想了一个取巧的办法:直接把 128x128 的缩放系数,复制到 4 份。

为了方便理解这个方案,我画了一张图。假设我们有一个 4x8 的 INT32 矩阵,按照 4x4 block 量化到 INT8,它会分成 1x2 个 4x4 的块,每块一个缩放系数,那就是 1x2 个缩放系数。如下图所示,第一个块的缩放系数是 7.3465,它是通过第一个块里的最大绝对值 |-933|/127 得到的,同理第二个缩放系数来自 974/127。

那如果我想将它的量化 block 缩小到 2x2,理论上我应该计算每个 2x2 block 的最大绝对值,然后 /127 得到缩放系数,这样精度损失最小。可是我嫌麻烦,偷个懒,我直接把 4x4 的缩放系数复制 4 份,虽然精度有损失,但好处是在计算上与 4x4 的量化结果完全一致。换句话说,就是原汁原味,纯血参数。

将这个逻辑迁移到 DeepSeek FP8 量化的 128x128 block 缩小到 64x64,原理是一样的,也是将 scale 参数矩阵进行 2x2 等值扩充。通过非常简单的参数处理,就能够实现将 DeepSeek 原始模型转成 64x64 的分块量化,然后就可以用 SGLang 加载运行了。

运行方法

我们以昨天发布的 DeepSeek-V3-0324 为例,逐步说明如何使用这种方法在 L40S 和 L20 上运行 FP8 满血+纯血版的 DeepSeek-V3-0324,不需要等待美团再发布 INT8 版本。

假设你已经下载好了模型,在 /workspace/DeepSeek-V3-0324/。那你需要先下载我的开发分支,并通过源代码安装它(如果遇到困难,建议你在 SGLang 的开发 Docker 中执行它):

git clone -b l40s-dsfp8 https://github.com/solrex/sglang.git
cd sglang
pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python

然后用下面这个脚本,将 128x128 量化的 DeepSeek-V3-0324,转换到 64x64 量化的 DeepSeek-V3-0324-Block64x64:

python3 scripts/resize_block_size.py /workspace/DeepSeek-V3-0324/

当你在 4 台(或 8 台)机器上都完成了 SGLang 安装和参数拷贝后,就可以用下面的命令来启动 SGLang 服务了。注意替换 MASTER_IP 和 TCP_IFACE 到正确的值。

# MASTER_IP: 主节点 IP
# TCP_IFACE: 主网卡接口名,可通过 ifconfig 获取

# 主节点
NCCL_DEBUG=INFO NCCL_IB_GID_INDEX=3 NCCL_SOCKET_IFNAME=TCP_IFACE GLOO_SOCKET_IFNAME=TCP_IFACE python3 -m sglang.launch_server --model /workspace/DeepSeek-V3-0324-Block64x64/ --tp 32 --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 0 --trust-remote --enable-torch-compile --torch-compile-max-bs 32 --cuda-graph-max-bs 32 --host 0.0.0.0 --port 8000

# 从节点 1
NCCL_DEBUG=INFO NCCL_IB_GID_INDEX=3 NCCL_SOCKET_IFNAME=TCP_IFACE GLOO_SOCKET_IFNAME=TCP_IFACE python3 -m sglang.launch_server --model /workspace/DeepSeek-V3-0324-Block64x64/ --tp 32 --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 1 --trust-remote --enable-torch-compile --torch-compile-max-bs 32 --cuda-graph-max-bs 32

# 从节点 2
NCCL_DEBUG=INFO NCCL_IB_GID_INDEX=3 NCCL_SOCKET_IFNAME=TCP_IFACE GLOO_SOCKET_IFNAME=TCP_IFACE python3 -m sglang.launch_server --model /workspace/DeepSeek-V3-0324-Block64x64/ --tp 32 --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 2 --trust-remote --enable-torch-compile --torch-compile-max-bs 32 --cuda-graph-max-bs 32

# 从节点 3
NCCL_DEBUG=INFO NCCL_IB_GID_INDEX=3 NCCL_SOCKET_IFNAME=TCP网卡 GLOO_SOCKET_IFNAME=TCP网卡 python3 -m sglang.launch_server --model /workspace/DeepSeek-V3-0324-Block64x64/ --tp 32 --dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 3 --trust-remote --enable-torch-compile --torch-compile-max-bs 32 --cuda-graph-max-bs 32

性能

在 MASTER 节点使用下面的命令进行性能测试(需要先下载测试数据集 ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json),输入固定 200,输出固定 200,并发 128,测试两轮。

python3 -m sglang.bench_serving --backend sglang-oai --dataset-path /workspace/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json --dataset-name random --random-range-ratio 1 --random-input-len 200 --random-output-len 200  --request-rate 128 --num-prompt 256 --max-concurrency 128 --host localhost --port 8000

我测试的性能指标是:

=========== Serving Benchmark Result ============
Backend: sglang-oai
Traffic request rate: 128.0
Max reqeuest concurrency: 128
Successful requests: 256
Benchmark duration (s): 108.13
Total input tokens: 51200
Total generated tokens: 51200
Total generated tokens (retokenized): 50917
Request throughput (req/s): 2.37
Input token throughput (tok/s): 473.49
Output token throughput (tok/s): 473.49
Total token throughput (tok/s): 946.98
Concurrency: 127.38
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 53802.80
Median E2E Latency (ms): 53530.89
---------------Time to First Token----------------
Mean TTFT (ms): 7768.52
Median TTFT (ms): 7769.45
P99 TTFT (ms): 12003.95
---------------Inter-Token Latency----------------
Mean ITL (ms): 232.35
Median ITL (ms): 220.48
P95 ITL (ms): 242.17
P99 ITL (ms): 270.50
Max ITL (ms): 10250.34
==================================================

对照上一篇博客《刷新 32 张 L40S 运行 DeepSeek-R1-INT8 的性能数据》,看起来 FP8 Block 量化的性能比 INT8 Channel 量化的性能要差一些。

代码改动

这次的代码改动不大,主要是参数转换脚本,和一些针对 64x64 的 tunning,以及 warning 的修复。整理好后我会提个 PR 给 SGLang,但这次我不确定这个 PR 是否会被接受,感兴趣的同学可以直接看这个 commit:https://github.com/solrex/sglang/commit/03d34078d8d65983aabc0386391743cc43f535ed or https://github.com/sgl-project/sglang/pull/4860

地震

正当我写到这的时候,忽然手机通知地震了。生平第一次收到,记录一下。我完全没震感,但是有朋友感觉到了。

刷新 32 张 L40S 运行 DeepSeek-R1-INT8 的性能数据

前一篇博客《使 SGLang 支持在 32 张 L40S 上运行 DeepSeek-R1》中提到我那非常特殊的 L40S 显卡配置,结果发现是个大乌龙。

首先,这台机器有 PCIE 4.0 Switch,每张 Switch 上插了 4 张 L40S 显卡。我误会的 PCI-to-PCI 应该是挂了一个给监控屏用的小显卡。

其次,有一张网卡插错位置了。本来应该每 4 张显卡配 1 张双口网卡,其中一张没插到对应的 Switch 上。

最后,3 张网卡 6 个网口,只启用了 1 个网口,有 5 个网口没启用。

所以,实际上所有机内通信走的都是 PCIE,所有跨机通信走的都是主网卡,这……只能怪自己没经验,默认以为交过来的环境都是对的。

折腾了几天,总算搞对了,正确的拓扑如下:

主网卡两个网口做了链路聚合,用作 TCP 通信;PCIE Switch 上的 4 个网口专用作 RDMA 通信。同机两个 NUMA 域双卡之间通信,走 PCIE 大概 11GB/s,走 GDRDMA 能提升到 19GB/s。最关键的是多机通信,从单网口的小水管提升到了 4 网口 GDRDMA。

采用与上篇文章同样不严谨的测试方式:

# 128 并发
[TP0] Decode batch. #running-req: 128, #token: 86816, token usage: 0.44, gen throughput (token/s): 777.92, #queue-req: 0, # 之前 565.73, 1.37x
# 32 并发
[TP0] Decode batch. #running-req: 32, #token: 8923, token usage: 0.04, gen throughput (token/s): 457.66, #queue-req: 0, # 之前 260.73,1.75x
# 4 并发
[TP0] Decode batch. #running-req: 4, #token: 1437, token usage: 0.01, gen throughput (token/s): 153.98, #queue-req: 0, # 之前 42.3,3.64x
# 1 并发
[TP0] Decode batch. #running-req: 1, #token: 482, token usage: 0.00, gen throughput (token/s): 49.02, #queue-req: 0, # 之前 26,1.88x

固定 200 输入,200 输出,128 并发,2 轮请求:

============ Serving Benchmark Result ============
Backend: sglang-oai
Traffic request rate: 128.0
Max reqeuest concurrency: 128
Successful requests: 256
Benchmark duration (s): 87.57
Total input tokens: 51200
Total generated tokens: 51200
Total generated tokens (retokenized): 51023
Request throughput (req/s): 2.92
Input token throughput (tok/s): 584.64
Output token throughput (tok/s): 584.64 # 之前 391.47,1.49x
Total token throughput (tok/s): 1169.28
Concurrency: 127.23
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 43524.37
Median E2E Latency (ms): 43246.71
---------------Time to First Token----------------
Mean TTFT (ms): 7225.06
Median TTFT (ms): 6769.52
P99 TTFT (ms): 14229.37
---------------Inter-Token Latency----------------
Mean ITL (ms): 182.87
Median ITL (ms): 162.95
P95 ITL (ms): 166.60
P99 ITL (ms): 202.39
Max ITL (ms): 13711.39
==================================================

可以看到,修复网络本身存在的问题以后,推理性能提升还是很显著的。

one more thing

前一篇文章提到: L40S/L20 虽然支持 FP8 精度,却不能运行 FP8 的 DeepSeek-V3/R1。这个问题搞定了,我已经实现在 L40S 上运行原始 FP8 参数的 DeepSeek-R1,满血+纯血。等我整理一下代码,下篇博客来介绍一下。

使 SGLang 支持在 32 张 L40S/L20 上运行 DeepSeek-R1

我提交的 PR: Support serving DeepSeek-R1-Channel-INT8 with 32 L40S. #4418 [1] 已经合入到了 SGLang 的主干,也许这是第一个用 PCIE 互联的 GPU 小卡跑通 DeepSeek-R1 推理的例子。

有一些遇到的问题分享一下,在适配别的 GPU 时可以用来参考。

背景

因为 128x128 block 量化的 DeepSeek-R1/V3,参数维度在除 32 时,会遇到商无法被 128 整除的问题,所以即使 L40S 支持 FP8,也无法直接用 TP32 运行 DeepSeek-R1/V3。

感谢美团提供了 channel 量化的 DeepSeek-R1 参数 DeepSeek-R1-Channel-INT8 [2],并且在 SGLang 代码库做了适配。这让我可以在 48G 显存的 L40S 上尝试一下运行满血版(int8 量化) DeepSeek-R1,但尝试过程没有我预期的顺利,遇到了不少问题。

问题

一、shared memory OutOfResources

  File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/triton_ops/extend_attention.py", line 356, in extend_attention_fwd
_fwd_kernel[grid](
...
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 374, in _init_handles
raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 102400, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

服务启动正常,收到推理请求后,Attention kernel 报 shared memory 资源不足。我刚开始没理解这个问题,后来看了下代码,发现 triton attention 代码中对不同类型的 GPU 架构设置了不同的 block size。L40S 属于 SM89 架构,但 SM89 都归类到了 SM80 架构里。

查了一下 CUDA 编程手册,SM89 的 shared memory 大小是 100K,但 SM80 是 160K。我猜测大概率就是这个问题,所以给 SM89 单独加了个分支,缩小了 SM89 的 block size,解决了这个问题。

后来发现 sglang/test/srt/test_triton_attention_kernels.py 可以完美复现这个问题,而我却傻乎乎地每次重启整个服务去测试正确性。

二、gemm executioin failed RuntimeError

  File "/usr/local/lib/python3.10/dist-packages/sgl_kernel/ops/__init__.py", line 118, in int8_scaled_mm
return torch.ops.sgl_kernels.int8_scaled_mm(
File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
return self._op(*args, **(kwargs or {}))
RuntimeError: gemm executioin failed, error: Error Internal

收到推理请求后,有很大的概率触发这个错误,但也有小概率能完成一次推理。所以我没想到是算子问题,我以为是显存不足。调了半天各种显存占用的参数,后来没办法了才回过头来看实际的算子调用。

又是跟上面类似的问题,sgl_kernel 中自定义的 int8 gemm 算子将 SM89 归类到 SM80 进行矩阵计算的 dispatch。这显然会遇到与上面类似的问题,但是我又不知道 SM89 该怎么进行 dispatch,看起来需要做很多 benchmark 或者计算才能确定。

于是我就去翻 TensorRT-LLM 和 vLLM,让我给翻到了 vLLM 的实现,我就照着 vLLM 对 SM89 的 dispatch 逻辑抄了一遍。这次我学乖了,先看看有没有 test。跑通了 sglang/sgl-kernel/tests/test_int8_gemm.py,才去进行集成测试。

三、sub-optimal MoE

Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.10/site-packages/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json

这是一个次要问题,看起来是一个专门的配置没有找到。后来我研究了一下,应该是 triton 版本的 fused_moe 需要读取一个在每种类型的 GPU 上都 benchmark 过的最好配置来运行。

python  benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model /workspace/DeepSeek-R1-Channel-INT8 --tp-size 32 --dtype int8_w8a8 --tune

我按照文档的要求,在 L40S 上跑了一下 bench,把最终输出的文件拷贝到对应的位置就好了。这个 bench 真的要跑好久,大概两三个小时。

性能

我这 4 台 L40S 的硬件配置有些特殊:它是单机 8 卡,PCIE 4.0 连接到主机,但它既不是 2-2-4,也不是 2-2-8。

20250322:其实是硬件配置错误,详见:《刷新 32 张 L40S 运行 DeepSeek-R1-INT8 的性能数据

其中 4 张卡直插 PCIE 上,另外 4 张卡通过一个 PCI-to-PCI Bridge 插到 PCIE 上。为了弥补这样连接的带宽缺陷,PCI-to-PCI Bridge 上还接了 2 个 100Gb/s 的 RDMA 网卡。主机上每个 NUMA 域,分别也有 2 个 100Gb/s 的 RDMA 网卡。

所以这 8 张显卡,6 张网卡,拓扑如下所示:

我也实在算不出来这玩意儿的互联带宽。逻辑上来说,这大概相当于用 4 PCIE 4.0 GPU + 2 RDMA 网卡的性能,所以也许这个 4 x 8卡,实际上相当于 8 x 4卡。以下性能评测结果供参考。

下面是加载时的显存使用情况:

[TP0] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=22.43 GB, mem usage=21.27 GB.
[TP0] Memory pool end. avail mem=7.95 GB
[TP0] Capture cuda graph begin. This can take up to several minutes. avail mem=7.92 GB
[TP0] Capture cuda graph end. Time elapsed: 411.41 s. avail mem=5.89 GB. mem usage=2.02 GB.
[TP0] max_total_num_tokens=201723, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=2049, context_len=163840

下面是使用固定 200 token 输入时,不同并发下的净 decode 速度。因为 bench 速度太慢,我也懒得等最终的结果,所以这里直接节取了 log。

# 128 并发
[TP0] Decode batch. #running-req: 128, #token: 86816, token usage: 0.43, gen throughput (token/s): 565.73, #queue-req: 0,
# 32 并发
[TP0] Decode batch. #running-req: 32, #token: 8923, token usage: 0.04, gen throughput (token/s): 260.73, #queue-req: 0,
# 4 并发
[TP0] Decode batch. #running-req: 4, #token: 1439, token usage: 0.01, gen throughput (token/s): 42.30, #queue-req: 0,
# 1 并发
[TP0] Decode batch. #running-req: 1, #token: 482, token usage: 0.00, gen throughput (token/s): 26.00, #queue-req: 0,

这是固定 200 输入,200 输出,128 并发,2 轮请求的完整压测结果:

============ Serving Benchmark Result ============
Backend: sglang-oai
Traffic request rate: 128.0
Max reqeuest concurrency: 128
Successful requests: 256
Benchmark duration (s): 130.79
Total input tokens: 51200
Total generated tokens: 51200
Total generated tokens (retokenized): 50992
Request throughput (req/s): 1.96
Input token throughput (tok/s): 391.47
Output token throughput (tok/s): 391.47
Total token throughput (tok/s): 782.94
Concurrency: 127.49
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 65135.99
Median E2E Latency (ms): 64974.27
---------------Time to First Token----------------
Mean TTFT (ms): 17554.19
Median TTFT (ms): 19216.02
P99 TTFT (ms): 21662.98
---------------Inter-Token Latency----------------
Mean ITL (ms): 239.84
Median ITL (ms): 220.95
P95 ITL (ms): 233.20
P99 ITL (ms): 299.49
Max ITL (ms): 16077.21
==================================================

链接

[1] https://github.com/sgl-project/sglang/pull/4418

[2] https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8

[3] https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh

理解 FlashMLA 在 DeepSeek MLA 计算过程中的位置和作用

读过 DeepSeek V2 Paper 的人可能对下面这张图印象非常深刻,非常直观地揭示了 MLA 算法的基本原理。但结合这张图或者 MLA 公式去看 DeepSeek 在开源周发布的 FlashMLA [1],可能就一脸懵逼了。

图1: MLA 与其它 Attention 对比 [2]

理解 FlashMLA 代码库解决的是怎样一个问题,还需要更进一步的理解 MLA 计算的优化过程。

DeepSeek MLA 公式

下面这张图是 DeepSeek V2 Paper 中的 MLA 公式,相信读到这里的人都知道这是个啥。所以我不做解释,放这里主要是为了方便与下面的图进行交叉阅读。

图2: MLA 计算公式 [2]

MLA Naive 实现

最直接的实现 MLA 的方式,就是实现图 2 MLA 的所有计算过程,这个计算过程也与图 1 完全一致。我以 DeepSeek V3 的参数为例,做了下面这张图。

为简化复杂度,这张图里隐藏了两个维度,batch size 和 seq len,或者你可以把它简单地理解成,仅输入一个 token 给模型进行计算的状态。就像我在前面的博客《DeepSeek-V3 MTP 工程实现思考》中做的那样:你仅输入“how”这一个单词给模型,看看模型能给你生成什么。

图3: MLA Naive 实现,转化为 MHA 计算

每个绿色的方框代表一次矩阵乘/线性变换,框内是参数矩阵的名字和维度;每个白色的方框代表一个中间结果 tensor,框内是张量名字和维度;黄色的方框则代表核心的 Attention 计算,也就是图 2 中的公式 (46)。参数矩阵和中间结果 tensor 的名字与图 2 保持一致。

在 Naive 实现中,512 维的 Latent KV cKV 被映射回对应 128 个 head,每个 head 128 维的 K kC 和 V vC,然后再拼接上位置向量 kR最终形成标准的 q、k、v,输入到标准的 Multi Head Attention 进行 Attetion 计算。与其他常见模型中 MHA 的唯一不同,可能是 head dim 192 不是 2 的 n 次方。

Naive 实现最直观,但它的问题是在 Decode 计算时性能不够好。Decode 计算时,输入的 Q 往往只有一个 token,但需要用到所有前缀 token 的 k 和 v,也就是我们通常说的 KV Cache。Naive 实现有两种选择:

① 缓存 Latent KV。缓存规模小,但 Latent KV 缓存不能直接送 MHA 计算,还得经过 WUK 和 WUV 的线性映射,可以看到这是两个规模不小的矩阵计算,而且每轮都得重复计算。

② 缓存 KV。缓存规模大,不用重复计算,性能好。但 MLA 的一大好处就是 KV Cache 压缩,这样显存内能缓存更多 token,支持更大的 batch 和 prefix cache。如果缓存 KV,在显存上对比 MHA 就完全没有优势了。

所以,Naive 实现可能会用于 Prefill,但在 Decode 计算时需要更好的计算方法

MLA 优化实现

很多人把下面这种 MLA 的优化称为矩阵吸收[3],来源是 DeepSeek V2 里面这样说:

Fortunately, due to the associative law of matrix multiplication, we can absorb WUK into WUQ, and WUV into WO. Therefore, we do not need to compute keys and values out for each query. Through this optimization, we avoid the computational overhead for recomputing kCt and vCt during inference.

但我更喜欢把它理解成矩阵乘法交换律。因为实际上大家发现,提前将两个参数矩阵乘起来,即把 (WUQ)TWUK 的计算结果做为新的参数矩阵,在性能上还不如分开计算[3]。既然实际计算过程是交换矩阵计算过程,从“矩阵吸收”角度思考反而更绕了。

图4: MLA 优化实现,转化为 MQA 计算

上图中的两个虚线箭头,显示了在优化的计算过程中,哪些参数矩阵被交换了位置。它们能交换的原因,就是从数学上这样修改是等价的(矩阵乘法交换律)。

与图 3 相比,可以看到输入给 Attention 的 q、k、v 形状发生了明显的变化。q 的形状由 128x192 变化成了 128x576,k 的形状由 128x192 变化成了 576,v 的形状由 128x128 变化成了 512。这样一来,原来的 KV 就不存在了,新的计算过程中只剩下 Latent KV 了。而且实际上 V 也不存在了,因为 V 就是 K 的前 512 维。

再回看图 1,你会发现,这不就是 MQA 么?而这就是实际上 FlashMLA 代码库解决的问题:提供了一个专门为 q k head dim 为 576,v head dim 为 512,v 与 k 的前 512 维重叠,q head 数不超过 128(TP 下会变少)设计,仅支持 H800/H100 GPU 的 MQA 优化算子。

简单来说:虽然这个库叫做 FlashMLA,但它提供的 flash_mla_with_kvcache() 是个 MQA 算子,只不过这个 MQA 的输入形状有些特殊。

小知识

为什么会这样呢?因为开源软件约定俗成的 Attention 算子封装,仅仅指图 2 中公式(46)这一行,是不包含前后的线性变换的。开源推理框架允许用户通过配置选择不同的 Attention 算子实现,比如 FlashAttention、FlashInfer、Triton 实现等。

虽然 MLA 算法的核心在前后的线性变换,FlashMLA 算子却不能提供这些变换。这些线性变换只能被实现在模型建模的 modeling 代码的 MLA 模块中,比如 SGLang 代码库 python/sglang/srt/models/deepseek_v2.py 文件中的 DeepseekV2AttentionMLA [4] Module。

引用

[1] https://github.com/deepseek-ai/FlashMLA

[2] DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model, https://arxiv.org/pdf/2405.04434v5

[3] DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子, https://zhuanlan.zhihu.com/p/700214123

[4] https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/deepseek_v2.py

2 行代码校验大模型(如DeepSeek-R1)权重文件下载完整性

很多人在 DeepSeek-V3/R1 爆火之后,都希望体验本地运行“满血版”模型。但是满血版模型的权重参数文件有 600 多个 G,光权重文件就拆成了 163 个。

当你受不了 HuggingFace 官网的下载速度,用其它方法或者渠道获得了权重文件后,怎么确认这些权重文件是完整无损坏的呢?

这里介绍一个最简单的方法,仅需要 2 行代码。

环境

前提 1,你已经 clone 了不含权重文件的模型 git 仓库。以 DeepSeek-R1 为例,通过下面命令可以仅 clone 代码文件到 DeepSeek-R1 目录下:

GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1

前提 2,你已经用某种方法下载好了权重文件。请将这些权重文件放到已 clone 的 git 仓库目录内,以 DeepSeek-R1 为例,就是将 163 个 *.safetensors 文件移动到 DeepSeek-R1 目录下。

你也可以不移动权重文件,那么你就需要在执行第 2 行命令前将 checksum 文件移动到权重文件所在目录。

第 1 行代码

获得所有官方权重文件的 sha256 checksum,并保存成一个标准的 checksum 文件。这行代码需要在 git 仓库目录下执行

git lfs ls-files -l | awk '{print $1"  "$3}' > large_files.sha256

这行命令输出的文件内容形如:

c2388e6b127ce6664e35c5e2529c3ce4bfc99f4f7fb6fa48e92b29ed5e4922af  model-00001-of-000163.safetensors
5f450c75da7eb897b74a092eee65df8bb115fce81cccd2bbaeb220bd97197875 model-00002-of-000163.safetensors
...
913177d9e0dfb228769e0a13a386c34b919dcbb32a430ce230979f53bf7ae5bc model-00163-of-000163.safetensors

第 2 行代码

根据官方权重文件的 checksum,检查本地文件的完整性。这个检查的执行速度会非常慢,因为它需要为每个文件计算 sha256sum,然后再与 checksum 文件做比对。

sha256sum -c large_files.sha256

这行命令的输出形如:

model-00001-of-000163.safetensors: OK
model-00002-of-000163.safetensors: FAILED
...
model-00163-of-000163.safetensors: OK

如果所有行的输出都是 OK,那么恭喜你,所有权重文件都没有损坏;如果有某行输出为 FAILED,就代表该文件没有通过完整性校验,你需要重新下载它。

此方法对所有标记为 LFS 的文件均有效,并不仅限于 *.safetensors 文件,比如量化模型 *gguf 权重文件,也可以同样用此方法校验。

单机 KTransformers 运行 DeepSeek-R1-GGUF 4 bit 量化模型 Q4_K_M 实测

最近有些文章把 KTransformers 吹得没边儿,但是看到的实测案例非常少。我也比较好奇它的实际表现,所以来实测一下看看。

机器配置

硬件实测环境官方案例
CPUIntel(R) Xeon(R) Platinum 8350C CPU @ 2.60GHz, 单插槽 32 核,64 超线程,2 插槽,2 NUMA 节点Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 sockets, 2 numa nodes
内存64GB DDR4 2933MHz x 24,共 1.5 TB 内存standard DDR5-4800 server DRAM (1 TB), each socket with 8×DDR5-4800
GPUNvidia L40S, 48GB VRAM4090D 24G VRAM

实测环境机器配置看起来很强悍,但距离 KTransformer 首页给的官方案例配置还是有差距:

  1. 8350C 是第 3 代至强 CPU,官方案例用的 6454S 是第 4 代至强 CPU,Intel AMX 指令集只在第 4 代和第 5 代至强上支持,号称比前一代有 3~10 倍的推理性能提升;
  2. DDR4 2933 的访存带宽,跟官方案例用的 DDR5 4800,纸面数据差 60%;
  3. 虽说 L40S 比官方案例用的 4090 性能要更强,显存要更大,但目前 KTransformers 给的配置并不能完全发挥出来。

程序环境

基于 Pytorch 镜像 pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel ,本地编译 KTransformers 命令:make dev_install 。

注:通过 pip install 的 KTransformers 包会在运行时 crash 在 cpuinfer.so 里,我猜测是官方的包使用了更高级的指令集,而我这台机器不支持。

执行命令:numactl -N 1 -m 1 python ./ktransformers/local_chat.py --force_think --model_path DeepSeek-R1/ --optimize_rule_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml --gguf_path DeepSeek-R1-GGUF/DeepSeek-R1-Q4_K_M/ --cpu_infer 33 --max_new_tokens 1000

参数说明

--force_think: 强制在开头输出 <think> 标签

--model_path: 来自于 git clone https://huggingface.co/deepseek-ai/DeepSeek-V3,只需要下载代码,不需要下载参数文件

--gguf_path: 下载的量化参数,只需要下载子目录:https://huggingface.co/unsloth/DeepSeek-R1-GGUF/tree/main/DeepSeek-R1-Q4_K_M

--cpu_infer: 粗看代码,KTransformers 是一个线程分发任务,其余线程消费计算任务,所以官方案例这里总是配置的 2 的 N 次方 + 1。但是在实测时,是否 2 的 N 次方 + 1 区别不大。

实测结果

Prompt: 【《背影》全文不含标题】 请简明扼要地回答问题:这篇文章出自哪里?作者是谁?写于什么年代?

图1:生成结果和性能

可以看到性能,生成整篇答案花费了 4 分多钟,prefill 性能 11 toks/s,decode 性能 3.4 toks/s。这个比官方案例的性能慢了一倍多。而且这个回答太啰嗦了,没有遵循 prompt。

图2: CPU 和 GPU 利用率

图 2 是通过 top 和 nvidia-smi 查看的运行时的 CPU 和 GPU 利用率:可以看到内存占用 380 GB,CPU 33 核差不多用满;GPU 占用 13GB 显存,利用率大概 9%。

图3:同样 prompt DeepSeek Chat 回答

上图是 DeepSeek-R1 chat 官网的回答,看到比量化版好很多,至少遵循了“简明扼要地回答问题”的指令。

其它尝试

为了测出来最大性能,我尝试过加大 --cpu_infer 参数到 65、129,尝试过切换 optimize_rules 到 DeepSeek-V3-Chat-multi-gpu-4.yaml 或者 DeepSeek-V3-Chat-multi-gpu-8.yaml,实测都没有看到性能优化,很多甚至还劣化到 0.8 toks/s。

但是降低 --cpu_infer 参数到 25,没有观察到性能劣化。

观察

从 GPU/CPU 的利用率可以看出,KTransformers 主要靠 CPU 推理,AMX 指令集和内存访问速度很关键,GPU 利用率很低,反而不关键。

DeepSeek-R1-Q4_K_M 4 bit 量化模型较非量化模型效果有显著差距,可以观察到指令遵循都不太够。

KTransformers 目前对计算任务的拆分,并没有实现跟随 CPU 核数线性提升性能,这说明也许里面还有很多优化可以做。

讨论

现在大模型动辄几百 GB,需要 N 张显卡才能运行,客观上阻碍了很多感兴趣的人去体验和创新。

KTransformer 能用低配硬件慢速跑起来(接近)满血的模型,是非常赞的一个项目。它后续的持续优化也值得期待。

但 CPU 推理在成本上有它的局限性,大公司也不傻,如果 CPU 成本低,为啥要买那么多 GPU 卡?

核算推理成本要尊重客观事实,别动不动“告别天价显卡”。进行压力测试后,将合理计算的服务器成本平摊到每个请求/token上来计算。从这个角度看,大概有些人又要鼓吹“告别天价 CPU 服务器”了。

DeepSeek-V3 MTP 工程实现思考

一个东西从 idea 到实现,中间有着巨大的鸿沟,DeepSeek-V3 的 Multi-Token Prediction 也一样。虽然开源社区很多在一个多月前已经支持了基于 TP 的 DeepSeek-V3 推理,但是 MTP 部分目前都还在开发中,进展最快的可能是 vLLM,参见 vLLM PR #12755 [5] 。

但我觉得这并不是一个终点,可能还有很多工程优化工作需要继续完成。下面我尽量用浅显的图表和语言来说明我的理解和思考,如有错误也欢迎指出。

Speculative Decoding (投机解码)

理解 MTP 首先要理解 Speculative Decoding,这里不过多介绍,仅用一张图说明 Speculative Decoding 的计算过程,便于理解后续的分析。如果希望深入了解可以观看 这个 Youtube 视频 [1]。

图1:自回归和投机解码示例,来自视频:EAGLE and EAGLE-2 [1]

左边展示的是常规的 LLM 自回归迭代计算过程。

初始 prompt 是 token: "how",how 先经过 embedding 计算变成 ehow,然后经过 Decoder Layer 的计算,输出特征向量 fhow(最后一层 hidden states),经过 LM Head 转换成一个概率分布向量 phow,通过采样得到生成结果 token:"can"。

然后 LLM 会把 can 作为新的输入,进行下一步的计算,直到 LLM 给出推理结束 token:"<EOS>"。

自回归解码是逐 token 迭代进行的,生成一个 token,将 token 作为输入生成新的 token,直到遇到结束条件:生成 "<EOS>" 或者达到最大生成长度。

右边展示的是 Speculative Decoding 的过程。

初始 prompt 还是 how,但是通过其它方式(比如一个小模型,叫做草稿模型)先推测了两个草稿 token:"can、we",同时输入到目标模型。普通的 Decoder 实现仅能解码 1 个 token,这里改造成能够同时解码输出 3 个 token 的 hidden states。这样我们就能同时得到:phow, pcan 和 pwe。然后就可以跟草稿模型输出的 qhow 和 qcan 进行比较,验证是否接受草稿模型的草稿 token:can 和 we。

图上目标模型验证结果是接受 can,但是拒绝 we,那就使用 pcan 采样,得到生成结果 token:I。这就意味着,投机解码通过一次推理,得到了两个 token:can、I实现了1倍逻辑加速:can 是推测以后得到验证的 token,I 是拒绝推测 we 以后,根据目标模型自身输出采样的 token。

EAGLE 与 DeepSeek-V3 MTP

EAGLE 简单说来是作者认为通过目标模型本身的特征向量(就是上面的 fhow)预测下一个 token 更准确,所以草稿模型使用了与目标模型基本相同的结构,利用了目标模型输出的特征向量(fhow)作为草稿模型输入。如果希望深入了解可以观看 这个 Youtube 视频 [1]。

图2:EAGLE 和 DeepSeek-V3 MTP 的区别 [2][3]

MTP 与 EAGLE 不同的点如上图所示,除了多做了一次 Norm 这种细节之外,主要是多步推理的时候的串行。EAGLE 在多步推理时,只使用到了一个草稿模型做自回归推理;MTP 在多步推理时,其实是多个草稿模型进行串行推理

了解完上面这些背景以后,我们可以分析如果希望实现 DeepSeek-V3 MTP,都需要做哪些工作。

MTP 实现

1. MTP 加载

虽然很多框架都支持了 EAGLE,但一般的实现,都只支持 1 个草稿模型。而 MTP 从设计上,需要加载多个草稿模型,每一个 MTP 层,都是一个草稿模型。

在推理的时候,要根据不同的 step,选不同的模型进行推理。这就使得 MTP 草稿模型的加载和推理的调度比其它投机编码要复杂。

但如果 MTP 的步长等于 1,那就相当于 1 个草稿模型,实现会简单很多。

2. MTP Prefill

图3:DeepSeek-V3 MTP [3]

从上图可以看出,第 i 个 MTP Module 的输入 token,是第 i+1 个 token 到第 n 个 token,n 是当前生成的总长度。而它不仅需要 token 的 embedding,还需要 token 在前一个模型计算得到的 hidden states。

比如 MTP Module 1 的输入,是 token 2 到 5 的 embedding 和 main model 最后一层输出的 token 2 到 5 的 hidden states。

这也就意味着,在完成 DeepSeek-V3 的 prefill 时,需要输出最后一层的 hidden states,才能进行第 1 个 MTP 的 prefill;第一个 MTP 输出最后一层的 hidden states,才能进行第 2 个 MTP 的 prefill,以此类推。

可以注意到:多个 MTP 的多次 prefill 计算是串行的。这意味着每增加 1 个 MTP Module,每次推理的时候就要多一轮串行的 prefill,并且多一份 kv cache。一个主模型加 N 个小模型的推理,可能会严重影响计算调度的效率,可能这也是为什么 DeepSeek-V3 只输出了 1 个 MTP Module 的原因。大概他们也认为,仅使用 1 个 MTP Module 性价比最高

3.MTP PD 分离

我在之前一篇博客[4]中列举了 PD 分离背后面临的很多架构选择,MTP 会让 PD 分离变得更复杂。框架有两种选择:

选择一:Prefill 节点做 MTP Prefill:如下图所示,P 节点做完 DeepSeek-V3 Prefill 以后,保留最后一层所有 token(除了第 1 个,即index 0)的 hidden states,采样生成的第一个 token,获得 tokenid,然后将这些输入到 MTP Module 1 做 Prefill。最后将 1) DeepSeek-V3 61 层的 KV Cache; 2) DeepSeek-V3 MTP 的 KV Cache; 3) DeepSeek-V3 生成的第一个 tokenid;4) DeepSeek-V3 MTP 生成的第一个草稿 tokenid 和概率;这 4 部分传给 D 节点。

图4:DeepSeek-V3 MTP Prefill PD 分离计算方案

选择二:Prefill 节点不做 MTP Prefill:P 节点做完 DeepSeek-V3 Prefill 以后,把:1) DeepSeek-V3 61 层的 KV Cache; 2) 最后一层所有 token(除了第 1 个,即index 0)的 hidden states;3)所有 token (除了第 1 个,即 index 0)的 embedding。这 3 部分传给 D 节点。D 节点将生成第一个 token 的 hidden states 经过 LM Head 计算和采样获得 tokenid,然后对 MTP 进行 Prefill。

考虑到通信量和复杂度,大概大家都会选择一,但这样 Prefill 节点就必须加载 LM Head 了,因为 MTP 依赖生成的 tokenid 做 embedding 输入。

4. MTP MLA 算子优化

由于 MLA 的复杂性,现在的很多 MLA 实现并不支持在 decode 单次前向计算时同时并行计算多个 Query token,所以只能通过 Batch Expansion 进行投机解码。

Batch Expansion

以 how [can, we] 举例,我们可以展开成 3 个请求:

图5:投机编码的 Batch Expansion 并行计算方法

从逻辑上来看,请求变多了,但 3 个请求放到一个 batch 中可以进行并行计算,可以共享 prefix cache (如果先做 prefill 的话),这样我们依然可以拿到 phow, pcan 和 pwe。通过并行请求也能够实现 1 次 Decode 验证多个 token。

这里要注意一个逻辑:虽然要验证 2 个 token,但是却展开成了 3 个请求。这样如果全部两个草稿模型投机推理的 token 都被接受了,那第 3 个 token 会由目标模型自己生成这个 token 被称为 bonus token

虽然 Batch Expansion 能解决投机编码时的并行问题,但 Batch Expansion 有一定的计算开销。在高吞吐的时候,会抵消投机编码带来的加速。更好的优化就需要 MLA 算子在单次前向计算时,同时 decode 2 个 query token,这有一定的改造成本。

5. MTP 并行和 overlap 优化

以 DeepSeek-V3 的参数规模,模型并行必不可少。尤其是考虑到微批计算、通信的 overlap 带来的高效率,MTP 的推理未必能像一般的草稿模型一样,单独执行。

很有可能需要将 MTP 的推理和 Speculative Decoding 的打分、验证融入到 DeepSeek-V3 模型中。通过一次前向计算,完成:1) 草稿 token 的打分和验证;2) 生成 token 的输出;3) 新草稿 token 的生成。类似于图 4,我就不画了。

结语

所以个人思考,DeepSeek-V3 MTP 的最优实现方式,很大可能是将 1 层与主模型融合在一起调度,而不是按照独立模型单独执行;在 PD 分离时由 Prefill 节点同时负责 MTP 的 prefill。

[1] EAGLE and EAGLE-2: Lossless Inference Acceleration for LLMs - Hongyang Zhang, https://www.youtube.com/watch?v=oXRSorx-Llg

[2] EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty, https://arxiv.org/abs/2401.15077

[3] DeepSeek-V3 Technical Report, https://arxiv.org/abs/2412.19437v1

[4] LLM PD 分离背后的架构问题, https://yangwenbo.com/articles/reflections-on-prefilling-decoding-disaggregation-architecture.html

[5] https://github.com/vllm-project/vllm/pull/12755

DeepSeek V3:AI 大模型 infra 基建新高度

AI 工程化

2021 年初,贾扬清在阿里云开始推广 AI 工程化这个概念,我非常认同。我在 21 年中做技术规划的时候,提出“AI 到生产力的转化,需要更高的工程化能力”,并且将 AI 工程化的实施总结为几大方向:

  • 语义索引场景化
  • 算力调度混合化
  • 模型研发标准化
  • 优化技术普惠化
  • 模型超大规模化
  • 架构系统智能化

我的 AI 工程化团队在这些方向上也取得了许多成果。

The AI Model

但 2022 年底 LLM 大流行以后,情况发生了一些变化。原因主要是 LLM 让 AI models 变成了 The AI model,虽然这个 model 很大,也多多少少有一些变种,但从工程实践的角度来看,它并不“复杂”。

很多时候,工程架构解决的是复杂性问题。

比如,TensorFlow、PyTorch、PaddlePaddle 这些训练框架简化了搭建和训练神经网络的复杂度,在一段时间内,各种结构的网络层出不穷,大部分都是依托这些框架来实现的。

而对于 LLM 来说,模型结构相对固定,虽然也使用了框架的一些外围能力,但是模型结构核心部分已经逐渐变成全手写以达成最佳性能,典型的实现包括 FlashAttention、TRT-LLM 等。

而且 LLM 的接口调用是自然语言,因而也变得极其简单,所有的 LLM 模型几乎可以使用同一套 API。

当时看起来 LLM 不需要太多的架构基建工作。

Prefix Caching 和 Kimi

我的这个认知在思考 prefix-caching 作用的时候,有了一些改变。

在《应该把 Prefix Caching 当作一种效果优化技术》这篇博客中,我提到 Prefix Cache Aware Scheduling 是一件非常值得做的事情。而且从 Kimi 发表的论文来看,他们已经在实践了,但其它的技术报告提到这些工程架构工作的不多。

DeepSeek V3

前几天 DeepSeek AI 发布了 DeepSeek V3 版本,我一边在吐槽这 670B 的模型参数太大,下载太慢,一边在阅读它的技术报告。结果发现他们在模型的部署上,玩得更高端,给了我一些新的震撼。

首先,prefilling 和 decoding 分开部署。prefilling 4 机 32 卡,decoding 40 机 320 卡。这样一来,我之前《LLM 推理优化 Continuous Batching 及其实现》这篇博客中提到的 Continuous Batching 就不再需要了。两阶段分开后,prefill 的计算过程(长度)是确定的,其算力利用是充分的,不再需要中间停下来插入新的请求。其实 prefilling 能够分开部署,跟 DeepSeek 以前的研究也是分不开的,DeepSeek V2 引入的 MLA 对 KV Cache 做了大幅度的低秩压缩,可以显著降低 KV Cache 从 prefilling 节点传递到 decoding 节点的带宽和延迟。

其次,MoE 专家分开部署。因为 MoE 专家的激活是 Token 级别的,也就是说每个 Token 会决定走哪个专家进行计算,分开部署就可能会带来负载均衡问题:有些专家太忙,有些专家太闲。DeepSeek V3 为了解决这个问题,还做了复杂的负载均衡策略。例如:快速识别较忙的专家,部署冗余的专家副本以承担压力;重新调整专家在不同节点的布局,尽量利用跨 GPU 带宽而减少跨节点带宽(因为 IB 比 NVLink 要慢);单卡冗余部署多专家,但通过全局的路由计算来优化专家的动态激活数量。

DeepSeek V3 还实现了计算和通信重叠。为了掩盖分布式计算过程中进行集合通信时的开销,将计算任务分为微批。一个微批进行集合通信时,进行下一个微批的计算。

此外,DeepSeek V3 在推理时还将 TP(Tensor)、DP(Data)、SP(Sequence)、EP(Expert)不同维度的并行化融合到了一起。单拿出来一种并行化方法也许现在的框架还是支持的,但这些方法组合在一起,我怀疑目前也没有什么推理加速框架能直接处理。

从技术报告中揭露的这些细节可以看出,为了发挥出模型的极致性能,DeepSeek 在 AI 大模型的分布式部署上花费了很大的心思。这也让 DeepSeek V3 成为目前公开资料可以看到的最复杂、最精巧的大模型 infra 设计

这些 idea 以前也许不是没有人想到,但是 infra 的演进是有很高研发和试错成本的。当 DeepSeek 将这些路走通以后,也许未来的很多大模型公司,大模型框架,都会往沿着这个方向继续演进。