我提交的 PR: Support serving DeepSeek-R1-Channel-INT8 with 32 L40S. #4418 [1] 已经合入到了 SGLang 的主干,也许这是第一个用 PCIE 互联的 GPU 小卡跑通 DeepSeek-R1 推理的例子。
有一些遇到的问题分享一下,在适配别的 GPU 时可以用来参考。
背景
因为 128x128 block 量化的 DeepSeek-R1/V3,参数维度在除 32 时,会遇到商无法被 128 整除的问题,所以即使 L40S 支持 FP8,也无法直接用 TP32 运行 DeepSeek-R1/V3。
感谢美团提供了 channel 量化的 DeepSeek-R1 参数 DeepSeek-R1-Channel-INT8 [2],并且在 SGLang 代码库做了适配。这让我可以在 48G 显存的 L40S 上尝试一下运行满血版(int8 量化) DeepSeek-R1,但尝试过程没有我预期的顺利,遇到了不少问题。
问题
一、shared memory OutOfResources
File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/triton_ops/extend_attention.py", line 356, in extend_attention_fwd
_fwd_kernel[grid](
...
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 374, in _init_handles
raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 102400, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
服务启动正常,收到推理请求后,Attention kernel 报 shared memory 资源不足。我刚开始没理解这个问题,后来看了下代码,发现 triton attention 代码中对不同类型的 GPU 架构设置了不同的 block size。L40S 属于 SM89 架构,但 SM89 都归类到了 SM80 架构里。
查了一下 CUDA 编程手册,SM89 的 shared memory 大小是 100K,但 SM80 是 160K。我猜测大概率就是这个问题,所以给 SM89 单独加了个分支,缩小了 SM89 的 block size,解决了这个问题。
后来发现 sglang/test/srt/test_triton_attention_kernels.py 可以完美复现这个问题,而我却傻乎乎地每次重启整个服务去测试正确性。
二、gemm executioin failed RuntimeError
File "/usr/local/lib/python3.10/dist-packages/sgl_kernel/ops/__init__.py", line 118, in int8_scaled_mm
return torch.ops.sgl_kernels.int8_scaled_mm(
File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
return self._op(*args, **(kwargs or {}))
RuntimeError: gemm executioin failed, error: Error Internal
收到推理请求后,有很大的概率触发这个错误,但也有小概率能完成一次推理。所以我没想到是算子问题,我以为是显存不足。调了半天各种显存占用的参数,后来没办法了才回过头来看实际的算子调用。
又是跟上面类似的问题,sgl_kernel 中自定义的 int8 gemm 算子将 SM89 归类到 SM80 进行矩阵计算的 dispatch。这显然会遇到与上面类似的问题,但是我又不知道 SM89 该怎么进行 dispatch,看起来需要做很多 benchmark 或者计算才能确定。
于是我就去翻 TensorRT-LLM 和 vLLM,让我给翻到了 vLLM 的实现,我就照着 vLLM 对 SM89 的 dispatch 逻辑抄了一遍。这次我学乖了,先看看有没有 test。跑通了 sglang/sgl-kernel/tests/test_int8_gemm.py,才去进行集成测试。
三、sub-optimal MoE
Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.10/site-packages/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json
这是一个次要问题,看起来是一个专门的配置没有找到。后来我研究了一下,应该是 triton 版本的 fused_moe 需要读取一个在每种类型的 GPU 上都 benchmark 过的最好配置来运行。
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model /workspace/DeepSeek-R1-Channel-INT8 --tp-size 32 --dtype int8_w8a8 --tune
我按照文档的要求,在 L40S 上跑了一下 bench,把最终输出的文件拷贝到对应的位置就好了。这个 bench 真的要跑好久,大概两三个小时。
性能
我这 4 台 L40S 的硬件配置有些特殊:它是单机 8 卡,PCIE 4.0 连接到主机,但它既不是 2-2-4,也不是 2-2-8。
20250322:其实是硬件配置错误,详见:《刷新 32 张 L40S 运行 DeepSeek-R1-INT8 的性能数据》
其中 4 张卡直插 PCIE 上,另外 4 张卡通过一个 PCI-to-PCI Bridge 插到 PCIE 上。为了弥补这样连接的带宽缺陷,PCI-to-PCI Bridge 上还接了 2 个 100Gb/s 的 RDMA 网卡。主机上每个 NUMA 域,分别也有 2 个 100Gb/s 的 RDMA 网卡。
所以这 8 张显卡,6 张网卡,拓扑如下所示:

我也实在算不出来这玩意儿的互联带宽。逻辑上来说,这大概相当于用 4 PCIE 4.0 GPU + 2 RDMA 网卡的性能,所以也许这个 4 x 8卡,实际上相当于 8 x 4卡。以下性能评测结果供参考。
下面是加载时的显存使用情况:
[TP0] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=22.43 GB, mem usage=21.27 GB.
[TP0] Memory pool end. avail mem=7.95 GB
[TP0] Capture cuda graph begin. This can take up to several minutes. avail mem=7.92 GB
[TP0] Capture cuda graph end. Time elapsed: 411.41 s. avail mem=5.89 GB. mem usage=2.02 GB.
[TP0] max_total_num_tokens=201723, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=2049, context_len=163840
下面是使用固定 200 token 输入时,不同并发下的净 decode 速度。因为 bench 速度太慢,我也懒得等最终的结果,所以这里直接节取了 log。
# 128 并发
[TP0] Decode batch. #running-req: 128, #token: 86816, token usage: 0.43, gen throughput (token/s): 565.73, #queue-req: 0,
# 32 并发
[TP0] Decode batch. #running-req: 32, #token: 8923, token usage: 0.04, gen throughput (token/s): 260.73, #queue-req: 0,
# 4 并发
[TP0] Decode batch. #running-req: 4, #token: 1439, token usage: 0.01, gen throughput (token/s): 42.30, #queue-req: 0,
# 1 并发
[TP0] Decode batch. #running-req: 1, #token: 482, token usage: 0.00, gen throughput (token/s): 26.00, #queue-req: 0,
这是固定 200 输入,200 输出,128 并发,2 轮请求的完整压测结果:
============ Serving Benchmark Result ============
Backend: sglang-oai
Traffic request rate: 128.0
Max reqeuest concurrency: 128
Successful requests: 256
Benchmark duration (s): 130.79
Total input tokens: 51200
Total generated tokens: 51200
Total generated tokens (retokenized): 50992
Request throughput (req/s): 1.96
Input token throughput (tok/s): 391.47
Output token throughput (tok/s): 391.47
Total token throughput (tok/s): 782.94
Concurrency: 127.49
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 65135.99
Median E2E Latency (ms): 64974.27
---------------Time to First Token----------------
Mean TTFT (ms): 17554.19
Median TTFT (ms): 19216.02
P99 TTFT (ms): 21662.98
---------------Inter-Token Latency----------------
Mean ITL (ms): 239.84
Median ITL (ms): 220.95
P95 ITL (ms): 233.20
P99 ITL (ms): 299.49
Max ITL (ms): 16077.21
==================================================
链接
[1] https://github.com/sgl-project/sglang/pull/4418