在智谱的 GLM-4-0414 系列模型发布后,我观察到一个有意思的现象。GLM-4-0414 相较于 GLM-4 修改了模型结构,初期仅支持通过 transformers 库推理,但你去搜索它有哪些推理支持的话,什么也找不到,被大量的垃圾自媒体文章淹没,很多相较于官方模型仓库的 README 来说都是 0 信息量。
有些垃圾自媒体,可能连个 AI 都不如。
说回来 GLM-4-0414,虽然它名字看起来是 GLM-4 的延续,但是仔细看 config.json 你就会发现,GLM-4 的模型结构是 ChatGLMModel,而 GLM-4-0414 的模型结构是 Glm4ForCausalLM。这就导致很多推理框架都要对其进行重新适配。
vllm 可能是最早支持 GLM-4-0414 的开源框架,看起来是智谱员工提的 PR,但是第一版实现有两个 bug,会导致模型加载失败或者输出错误结果。由于对 GLM-Z1-9B-0414 模型在报告中声称大幅优于 DeepSeek-R1-Distill-Qwen-14B 过于好奇,我忍不住自己尝试在 SGLang 里适配了一下 GLM4-4-0414,见 PR#5485。
嗯,学到了很多东西,也踩了两个 bug 其中的一个,哭。就说说这两个 bug 吧,其实都是源自于对 vllm/SGLang 库算子的不了解。
在 transformers 库中,很多算子仅实现其算子名表示的朴素的功能,但是 vllm/SGLang 代码库的一些算子,除了其朴素功能以外,往往还通过更多的参数实现对前后计算的融合、简化,导致其实际计算需要深入探究。
BUG 1: 融合 RMSNorm
以比较基础的 LLaMA 为例,transformers 中的 LLaMA DecoderLayer 的实现是这样的(删去了一些无关细节),这个跟大家理解的 Transformer 模型结构伪代码是容易对上的。
def forward():
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn()
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
return outputs
但是 vllm/SGLang 中的实现是这样的:
def forward():
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(...)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
粗看你会有点懵,仔细研究就会发现,SGLang 里面通过给 RMSNorm 传入第二个参数,实现了 Norm 与 Add 的融合。但是这种融合需要调整计算顺序,影响了参数和返回值的类型,并且也影响了最后一次 model.norm() 的计算。
相似的还有 SiluAndMul() 算子。
BUG 2: get_rope 参数
GLM-4-0414 的 config.json 中提供了 partial_rotary_factor 参数,作用于 head_dim 上。有两种应用方法,一种是提前计算好 rotary_dim = int(partial_rotary_factor * self.head_dim),然后把这个参数传进去;另一种是令 rotary_dim = head_dim,然后传入 partial_rotary_factor。
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: int,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
...
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
我重复犯的这个错误就是将 partial_rotary_factor 应用了两遍,rotary_dim 也计算了,partial_rotary_factor 参数也传了(因为参考了 vllm 实现和旧的 SGLang chatglm 实现,甚至看到别人提交的 bugfix 都不认为这里有错),其实就相当于应用了 partial_rotary_factor * partial_rotary_factor。这个 BUG 的现象就是导致 GLM-4 模型的输出大概率陷入死循环而无法结束。
BTW:transformers 库里的 glm4 代码没有读这个参数,而是硬编码在 apply_rotary_pos_emb() 算子中,所以这不是一个可调参数。
PR#5485 还在 Review 中,有需要的朋友可以使用 https://github.com/solrex/sglang/tree/glm4 这个分支支持 GLM-4-0414 系列模型在 SGLang 的推理。
git clone -b glm4 https://github.com/solrex/sglang.git
cd sglang
pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
python3 -m sglang.launch_server --model /workspace/GLM-Z1-9B-0414 --enable-torch-compile --torch-compile-max-bs 128 --cuda-graph-max-bs 128 --host 0.0.0.0 --port 8000
《有关 GLM-4-0414 的 SGLang 推理支持》上有1条评论