应该把 Prefix Caching 当作一种效果优化技术

我在 4 月份写的博客《LLM 推理优化 Prefix Caching 及其实现》带来了不少流量,这篇博客在 Google 搜索“Prefix Caching”时排在比较靠前的位置,也给我的微信公众号“边际效应”带来了超过 100 个关注者,由此可以看到大家对这项技术的广泛兴趣。

新认知

最近关于 Prefix Caching 我又产生了一次认知升级:在一个 Eureka 时刻,我忽然领悟到 Prefix Caching 实际上可以是一种效果优化手段——但让其充分发挥效能需要跟长上下文语言模型 Long-Context LLM 和基于 Prefix Cache 的调度技术相结合

RAG v.s. Super Long Domain Specific Prefixes

为了减少模型的幻觉,提升时效性,我们往往会采取 RAG 技术来增强模型。召回的结果本身会增加 Prompt 的长度,会加大 Context 的计算量,会影响模型的推理时长,因而召回结果的补充也是谨慎克制的,不会让 Prompt 变得非常长。

但是 Prefix Caching 优化技术给我们开了个口子:如果把信息放在 Prompt 的共享 Prefix 中,加长 Prompt 的代价就由计算和时延,转化到了存储。而存储代价可以通过 Cache 复用来摊薄,这就是一笔很划算的经济账。

如果把 RAG 召回的结果当作 Prompt 里的变量,那么 Prefix 就是 Prompt 里的常量,当增加变量信息的规模受到限制时,增加常量信息的规模也可以提升生成的效果。

举个例子:如果你做 K12 领域的模型,把 12 个年级的语文知识点都放在 Prefix 里,然后再叠加 RAG,然后回答用户语文方面的提问,肯定能更好地保证生成的效果。

所以,LLM 技术后续的一个新发展分支,也许会是超长特定领域专用的前缀设计

Inference Time Scaling

最近讨论很多的一个技术方向,是从基于预训练提升效果的 Scaling Law 转向基于 reasoning at inference time 的 Scaling Law,使用更多的推理计算来获得更好的效果。

但这些 Scaling Law 还是基于“computing”的 Scaling,我认为也许会存在基于“memory”的 Scaling Law,即更长的共享 Domain Specific Prefix 带来更好的效果。因为这里的“memory”是计算的缓存,memory 的共享本质上是计算的共享。

Long Context Large Language Model

在无法共享计算的场景下,长上下文的大语言模型应用往往由于其计算成本或显著延迟而无用武之地。但如果基于 memory 的 Scaling Law work 的话,那长上下文大语言模型将成为必经之路

这也许是 Moonshot 和 Gemini 早早投入百万 Token 级别长上下文模型的背后逻辑

Prefix Cache Aware Scheduling

在没有理想的大一统 Prefix 之前(大概率也不可能),共享 Prefix 只能是 Domain Specific 才最合适。那就意味着会有很多个共享 Prefix,而 Prefix Cache 占用的空间又是不可忽视的,所以不可能在每张卡/每台机器上都存储所有的共享 Prefix。

这就要求在用户请求调度时,根据 Prefix 分配到不同的卡/机器/集群上。Prefix Cache 也需要能够根据请求的热度在不同的服务器间弹性扩散或者收缩,甚至在显存、内存、SSD 之间调度。其实 Kimi 的论文“Mooncake: A KVCache-centric Disaggregated Architecture for LLM Serving”就是在解决这个问题,只是我之前没有意识到这个技术的广泛重要性

LLM 推理优化 Prefix Caching 及其实现

Prefix Caching

上文中提到,Prompt 计算(Prefill 阶段)和生成阶段的计算特性很不相同。为避免重复计算,所有框架的 prefill 阶段主要作用就是给迭代的生成阶段准备 KV Cache。但这些 KV Cache 仅仅是为单次生成式请求服务的,那很自然的一种想法就是,KV Cache 能不能跨请求复用?

在某些场景下,多次请求的 Prompt 可能会共享同一个前缀(Prefix),比如拟人 Agent 的人物设定,文档阅读理解时的文档内容等。这些情况下,很多请求的前缀的 KV Cache 计算的结果是相同的,像一般互联网服务的请求缓存一样,可以被缓存起来,给下一个请求复用。

限制

但 KV Cache 跟其它服务缓存不一样的地方是,它太大了,以至于(目前)很难通过 Redis/Memcache 这种分布式缓存服务存取。比如对 13B LLM 模型来说,在 FP16 精度下单 token 的 KV Cache 大约是 1MB,假设要缓存的前缀有 500 个 token(大约800多个汉字),那就是 500MB。一般来说,我们不会每次请求去从分布式系统里读取/传输 500MB 的缓存,甚至都不会每次请求从内存往显存中拷贝 500MB 的缓存,所以大部分情况下,prefix cache 都会放在显存里。

这也就意味着,如果你想命中 prefix cache,必须把相同 prefix 的请求发到同一张 GPU卡上才行。

实现

由于不是普遍需求,加上前面说的限制,prefix caching 作为一个加速特性,不是很受关注,一般也不是默认开启的。各框架的实现和配置略有差异,这里简单做下记录,便于回顾。

刚开始 vLLM 的实现是给 generate 接口增加一个 prefix_pos 参数,通过 prefix_pos 输入参数为每个请求指定 prefix 长度,为 prefix 建一个带淘汰的哈希缓存。后来觉得这样做使用上不够便利,升级成了自动前缀缓存,即将 prompt 的 kv cache 分成 block,然后为 block 建设 LRU 缓存机制,这样就不必在接口上使用 prefix_pos 指定哪部分是 prefix 了。自动前缀缓存功能默认是不开启的,开启的配置项为 --enable-prefix-caching

TensorRT-LLM 与 vLLM 后来的实现类似,也是实现了 block kv cache,配置项是 enableBlockReuse,默认也是不开启的。代码未开源,无法看到实现。

Lmdeploy 的 PythonTurboMind C++ 版本的 prefix caching 功能都已经有了 PR,但现在(20240425)看还没有合入主干。有意思的是它没有使用 hash block 对应 token_id 子串的所有 token_id 前缀然后组成哈希表的方式,而是用 hash 当前 block 对应的 token_id 子串然后组成 trie 树的缓存管理结构。默认的参数名与 vLLM 相同,也叫做 --enable-prefix-caching。

HuggingFace TGI 现在看起来还没实现 Prefix Caching 功能。

Prompt Caching

除了 Prefix Caching 这种比较直观的工程优化,现在也有一些研究在看 Prompt 的其它缓存机制。比如设计一种机制让 prompt 模块化,不仅可以复用 prefix,还能复用中间的部分;或者通过 query 的相似性复用其它 query 的 prompt

但目前看实现上都过于复杂,比如第一种要求模型使用不连续的 poition_id,这样就可以插入 token,但这种方式对 attention 的计算机制有一定的影响,难以说明它对效果的影响。