SgLang代码细读-3.Cache
<h2 id="本地显存池">本地显存池</h2><h3 id="数据结构">数据结构</h3>
<p>因为kv cache有MHA,MLA,DoubleSparse 等多种自定义类型,需要进行一步抽象将框架和cache类型做隔离, 所以有了2级内存池的设计. 一级保存和cache类型无关的数据(token位置),跟具体业务隔离,二级给出抽象类接口, 不同的cache类型按需继承实现interface, 就能通过配置来进行管理.</p>
<h4 id="二级显存池">二级显存池</h4>
<h5 id="req_to_token_pool"><code>req_to_token_pool</code></h5>
<pre><code class="language-python">class ReqToTokenPool:
"""A memory pool that maps a request to its token locations."""
def __init__(
self,
size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
):
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
self.size = size #size对应的是server_args.max_running_requests
self.max_context_len = max_context_len#对应的是从模型配置里读出来的支持最大的上下文长度
self.device = device
with memory_saver_adapter.region():
self.req_to_token = torch.zeros( #2维, 第一维偏移代表是第几个req, 第二位偏移记录在req中token在二级池的索引
(size, max_context_len), dtype=torch.int32, device=device
)
self.free_slots = list(range(size)) #1维, 用于记录哪些req被释放掉了, 在后续的请求可以复用
</code></pre>
<h5 id="token_to_kv_pool"><code>token_to_kv_pool</code></h5>
<p>功能: 将 token 的 KV Cache索引映射到其 KV Cache数据, 实际的实现中这个依然是2大类组合形成的, 包括<code>PagedTokenToKVPoolAllocator</code>与<code>KVCache</code>接口类和其对应的子类 (只看了page_size>1的实现)</p>
<p><code>PagedTokenToKVPoolAllocator</code>主要负责kv分页后的页表管理, 存储的数据是<code>free_pages</code>, 假设page_size=4. 初始化状态如下:</p>
<pre><code class="language-markdown">+---------+---------+---------+---------+
| Page 1| Page 2| Page 3| Page 4|
| 4~7 | 8~11 | 12~15 | 16~19 |
+---------+---------+---------+---------+
free_pages:
alloc:
分配8个token后:
free_pages:
分配到的索引:
free:
传入要回收的 token 索引(如 ),会通过idx / page_size转换为页索引 ,并加回free_pages,变为
</code></pre>
<p>KVCache子类, 以MLA为例, kv_buffer为<code>layer_num</code>个torch.Tensor, 存储了<code>k_buffer</code> 的 <code>cache_k</code> 和 <code>v_buffer</code> 的 <code>cache_v</code>, 每个tensor的dim分别表示: (最大token数 + page_size, head_dim(MLA里是1),head维度)</p>
<p>在MLA里, head维度是 LoRA相关的KV维度(低秩适配部分) + QK经过RoPE后的维度</p>
<p>第一维要加page_size的原因是: 如果不加, 某些操作(如<code>buffer</code>)会越界, 加上page_size后可以避免这些越界情况的判断, 简化逻辑</p>
<pre><code class="language-python"> with memory_saver_adapter.region():
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]
</code></pre>
<h4 id="hostkvcache">HostKVCache</h4>
<p>Hierarchical Caching(分层缓存)机制, 支持一部分kvcache通过offload方式放到内存里. 由于会影响推理速度暂没用到. 待有需求的时候再细看.</p>
<h3 id="显存allocfree">显存alloc/free</h3>
<h5 id="alloc">alloc</h5>
<p>从二级显存池申请空间逻辑都在forwardBatch.prepare_for_extend/prepare_for_decode里面, 以extend为例, 分为几步:</p>
<ol>
<li><code>alloc_req_slots</code>: 根据batch_size, 从req_to_token_pool中申请bs个free_req对应的token索引.</li>
<li>遍历reqs, 把刚才申请到的<code>req_pool_indices</code>填到对应req的<code>req_pool_idx</code>成员里, 使其能够一一对应
<ol>
<li><code>get_last_loc</code>: 获取每个请求前缀最后一个token在req_pool_indices中的索引</li>
<li><code>token_to_kv_pool_allocator.alloc_extend</code>: 计算逻辑在alloc_extend_kernel里,
<ul>
<li>第一步: 因为之前算出了最后一个token在显存中的偏移, 根据这个偏移和page_size能拿到最后token所在页和还有多少剩余空间, 先把这页没满的空间填满.</li>
<li>第二步: 从free_page里拿出新页继续填充</li>
<li>第三步: 分配最后一页, 如果填不满, 就把剩下的token填到这一页的前几个里面</li>
</ul>
</li>
</ol>
</li>
</ol>
<h5 id="free">free</h5>
<p>在Cache复用中决定这些cache何时被回收, 通过调用<code>token_to_kv_pool_allocator.free</code>和<code>req_to_token_pool.free</code>处理. 核心逻辑:</p>
<pre><code class="language-python">free_page_indices = torch.unique(free_index // self.page_size)#会把所有 token 索引转换为页号(同一页的 token 都会变成同一个页号)。
self.free_pages = torch.cat((free_page_indices, self.free_pages))#塞回free_pages
</code></pre>
<p>只要页内有任意 token 被回收,这一整页就会被回收</p>
<h3 id="kvcache读写">KVCache读写</h3>
<p>在attn_backend的forward函数中</p>
<p>P阶段把kv写到kv_buffer里, 即根据cache_loc到KVCache子类中对应的偏移中将torch.Tensor的值复制过去</p>
<p>D阶段根据layer_id读出对应layer的cache.</p>
<pre><code class="language-python"> if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
# Call the wrapped function
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
)
</code></pre>
<h2 id="cache复用">Cache复用</h2>
<h3 id="radixcache">RadixCache</h3>
<h4 id="数据结构-1">数据结构</h4>
<p>基于基数树<code>RadixTree</code>数据结构实现的Cache, 其实就是压缩版的前缀树. 一看图就能弄清楚:</p>
<img src="https://img2023.cnblogs.com/blog/1439743/202505/1439743-20250522161641192-1536091629.png" alt="image-20250521174326143" style="zoom: 50%">
<p>查询: 一直DFS到没有公共前缀为止.</p>
<p>插入:以root为起点遍历, 对当前节点做前缀匹配, 长度>0就进入子树否则进入兄弟节点. 一直DFS到没有公共前缀为止, 把不相同的str插入到新叶节点上.</p>
<h4 id="插入与驱逐">插入与驱逐</h4>
<img src="https://img2023.cnblogs.com/blog/1439743/202505/1439743-20250522161646030-100484665.png" alt="image-20250521173520016" style="zoom: 50%">
<p>前缀匹配代码解析:</p>
<pre><code class="language-python"> def _match_prefix_helper(self, node: TreeNode, key: List): #传入的node就是root
node.last_access_time = time.time()
child_key = self.get_child_key_fn(key)
value = []
while len(key) > 0 and child_key in node.children.keys(): #非递归版dfs, 非page时当key中的第一个不在node的child中退出.(即完全不匹配)
child = node.children
child.last_access_time = time.time()
prefix_len = self.key_match_fn(child.key, key) #树节点和当前token_id list进行前缀匹配
if prefix_len < len(child.key): #部分匹配
new_node = self._split_node(child.key, child, prefix_len) #分裂不匹配的那部分, 挂到当前节点下面作为child
value.append(new_node.value)
node = new_node
break
else:
value.append(child.value) #完全匹配, 进入子节点继续遍历, 把已经匹配成功的节点加到结果里
node = child
key = key #去掉已经匹配过的前缀
if len(key):
child_key = self.get_child_key_fn(key)
return value, node
</code></pre>
<p>驱逐使用了引用计数(lock_ref)用于记录当前cache有没有在使用, 当叶子的引用计数为0时可以驱逐释放. 参考函数<code>dec_lock_ref</code>, 注意这里lock_ref在减这个node时, 会把他的所有父节点路径全都减1. 驱逐代码解析:</p>
<pre><code class="language-python"> def evict(self, num_tokens: int):
if self.disable:
return
leaves = self._collect_leaves() #通过BFS方式获取到树上的所有节点
heapq.heapify(leaves) #把树list转成堆, 通过TreeNode中的__lt__进行比较排序, 其实就是比last_access_time
num_evicted = 0
while num_evicted < num_tokens and len(leaves): #循环pop heap
x = heapq.heappop(leaves)
if x == self.root_node:
break
if x.lock_ref > 0: #引用计数>0的叶子跳过
continue
self.token_to_kv_pool_allocator.free(x.value)#释放ref_count=0的kvcache
num_evicted += len(x.value)
self._delete_leaf(x) #在树上删掉这个叶节点
if len(x.parent.children) == 0: #如果这个叶节点的父节点, 被删除这个child后也变成了叶节点, 把他push进heap
heapq.heappush(leaves, x.parent)
</code></pre>
<h4 id="cache_requestreq">cache_request(req)</h4>
<ol>
<li>从<code>req_to_token_pool.req_to_token</code>获取<code>kv_indices</code></li>
<li>把当前这条请求更新到Radix Cache (<code>insert()</code>)</li>
<li>finished: 释放这条请求的KV Cache, unfinished: 更新这条请求在req_to_token_pool中的偏移</li>
<li>finished: 把这条请求的last_node引用计数-1, 标识可以evict, unfinished: 如果开了page, 把req里的last_node 引用计数-1, 把页对齐的last_node 引用计数+1</li>
</ol>
<h3 id="chunkcache">ChunkCache</h3>
<p>对于过长的token请求, 如果在一个batch内处理除了会及其占用显存资源导致显存超限外, 还有可能因为单请求无法并行处理严重影响其他请求的TTFT, 所以有了chunked_prefill这个功能, 主要作用就是将过长请求切分成多个chunk分别进行处理</p>
<p>sglang的实现在同一时期只能有一个请求在chunk, 而chunk请求在处理时和其他请求的不同点在于: 当前chunk在进行attention计算时, 需要依赖此前的chunk计算的kvcache. 如下图:</p>
<img src="https://img2023.cnblogs.com/blog/1439743/202505/1439743-20250522161651699-644053946.png" alt="image-20250521111843935" style="zoom: 50%">
<p>因此就有ChunkCache这么个东西, 专门用来处理图中绿色部分的kvcache.</p>
<p>cache_unfinished_req: 把没处理完的req当前chunk的显存池的偏移量取出来, 塞到prefix_indices里用于下一个chunkReq的构建.</p>
<p>cache_finished_req: 把当前chunk和之前的prefix chunk kvcache直接全部free掉</p>
<h2 id="pd分离kvcache通信">PD分离,KVCache通信</h2>
<p>相关代码在python/sglang/srt/disaggregation中. 包括5种class:</p>
<ol>
<li><code>TransferBackend</code>: 枚举类, 用于记录server_args指定的kvTransfer后端, 把不同的KVCache通信后端封装到相同的接口内方便框架兼容(mooncake/nixl)</li>
<li><code>BaseKVManager</code>: 抽象类接口, 每个后端自己实现. 管理KV通信线程, 以及P和D的连接关系. 绑定ZMQ用于D节点和P节点的TCP通信. P节点有两个ZMQ监听线程(Bootstrap和transfer), D节点只有一个decode线程.</li>
<li><code>BaseKVBootstrapServer</code>: 抽象类接口, 每个后端自己实现. 用于P节点接收 D节点alloc完成后发送的Notify请求. 在这个类中起一个新线程, 通过event_loop监听一个端口接收请求.</li>
<li><code>KVSender</code>: 抽象类接口, 用于P节点的请求发送(send接口)和状态查询(poll)</li>
<li><code>KVReceiver</code>: 抽象类接口, 用于D节点的请求接收(recv接口)和状态查询(poll)</li>
</ol>
<p>注意在代码中会看到除了KV本身还有一类叫aux data, 是 <strong>auxiliary</strong> 的缩写,表示“辅助数据”, 比如位置编码、mask、attention map、LoRA 相关参数等.</p>
<h3 id="建立连接">建立连接</h3>
<ol>
<li>
<p>P节点注册自身信息到<code>MooncakeKVBootstrapServer</code></p>
<ul>
<li>每个P 节点启动时,会通过 HTTP PUT 请求,把自己的 rank_ip、rank_port等信息注册到 bootstrap server。</li>
<li>bootstrap server 会把这些信息按 DP/TP 分组存入prefill_port_table,以便后续 decode 查询</li>
</ul>
</li>
<li>
<p>D节点查询需要连接的P节点</p>
<ul>
<li>D节点初始化时,会根据自己的 engine_rank、dp_group 等参数,通过 HTTP GET 请求向 bootstrap server 查询自己应该连接的 P 节点信息 <code>_get_bootstrap_info_from_server</code></li>
<li>查询参数为 engine_rank 和 target_dp_group,bootstrap server 返回对应 P节点的 IP 和端口</li>
</ul>
</li>
<li>
<p>D节点拿到P节点的 IP/端口后,通过 ZeroMQ 建立 socket 连接, 然后把自己的KVCache相关信息通过 ZeroMQ 发送给P节点(KVReceiver里init方法),完成后续的数据同步和传输</p>
</li>
</ol>
<h3 id="发送数据">发送数据</h3>
<pre><code class="language-python">def _init_kv_manager(self) -> BaseKVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_data_ptrs, kv_data_lens, kv_item_lens = ( #从显存池里拿到token KV Value的起始地址
self.token_to_kv_pool.get_contiguous_buf_infos()
)
kv_args.kv_data_ptrs = kv_data_ptrs #把这个显存地址传到BaseKVManager里用于初始化
kv_args.kv_data_lens = kv_data_lens #从而在send的时候只传kv_indices, TranferEngine也能知道从哪里取出kv value
kv_args.kv_item_lens = kv_item_lens
#...
</code></pre>
<ol>
<li>P节点从bootstrap队列拿出已经变成WaitForInput状态的请求, 即到达下图的Notify收到的状态. 初始化<code>req.disagg_kv_sender</code></li>
<li>完成forward后, 通过<code>send_kv_chunk->disagg_kv_sender.send()</code>函数把<code>kv_indices</code> copy到内存添加到通信队列里.</li>
<li>tranfer异步线程从队列里取KVCache索引, <code>send_kvcache</code>, <code>group_concurrent_contiguous</code>先把连续的内存块进行切分. 根据每个layer从线程池里取一个独立的线程进行并发通信, 最后等所有layer完成通信. (engine.transfer_sync 调用的是mooncake的内部方法, 需要后续细看mooncake代码, 记个TODO)</li>
</ol>
<h3 id="接收数据">接收数据</h3>
<p>之前在notify的时候把D节点要接收的显存地址也传过去了, 通过TransferEngine直接完成从显存到显存的copy.</p>
<p>P节点在通信完成后调用<code>sync_status_to_decode_endpoint</code>, 通过ZMQ告知D节点完成传输</p>
<p>D节点的<code>start_decode_thread</code>在收到传输完成通知后, 更新status. 这样就完成了KVCache的整体传输</p>
<img src="https://img2023.cnblogs.com/blog/1439743/202505/1439743-20250522161700884-1116360292.png" alt="image-20250521205759880" style="zoom: 50%">
<p>参考: https://zhuanlan.zhihu.com/p/31160183506</p>
<p>Sglang kvcache code walkThrough: https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/sglang/kvcache-code-walk-through/readme-CN.md</p>
<p>MLA细节解析: https://zhuanlan.zhihu.com/p/19585986234</p>
<p>RadixAttenion解析: https://zhuanlan.zhihu.com/p/693556044</p><br><br>
来源:https://www.cnblogs.com/sunstrikes/p/18891538
頁:
[1]