Optimize Seko AR KV cache#1109
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces KV cache buffer reuse to avoid reallocation, adds a fused Triton kernel for KIVI dequantization, and refactors several model runners to reuse the cache manager. Feedback on these changes highlights critical issues: replacing torch.zeros with torch.empty and removing zero_() in reset() introduces risks of NaN propagation and sensitive data leakage (CWE-226) across sessions. Additionally, using a simple boolean check for kv_quant in the buffer signature can cause incorrect cache reuse when quantization configurations change, and calling torch.distributed.barrier without an explicit sequence parallel group can lead to deadlocks in distributed environments.
| def reset(self) -> None: | ||
| self._k_buffer.zero_() | ||
| self._v_buffer.zero_() | ||
| pass |
There was a problem hiding this comment.
Removing zero_() in reset() and switching to torch.empty introduces two major risks:
- NaN Propagation / Correctness Issues:
torch.emptyallocates uninitialized memory which can contain arbitrary garbage, includingNaNorInfvalues. If the attention mechanism or any masking operation performs computations over the entire buffer (even if masked out by multiplying with 0, since0 * NaN = NaN), theseNaNs can propagate and corrupt the model's outputs. - Data Leakage (Security Vulnerability - CWE-226): Since the
KVCacheManagernow reuses cached KV buffers across different inference requests (by matchingbuffer_sig), not zeroing out the buffers onreset()means that sensitive data from a previous request remains in memory. If a subsequent request (potentially from a different user) reads beyond its current sequence length due to a bug, speculative decoding, or different sequence lengths, it could leak information from the previous request.
Recommendation: Revert to using torch.zeros or ensure that reset() explicitly zeroes out the buffers (e.g., using zero_()) to maintain correctness and security.
| def reset(self) -> None: | |
| self._k_buffer.zero_() | |
| self._v_buffer.zero_() | |
| pass | |
| def reset(self) -> None: | |
| self._k_buffer.zero_() | |
| self._v_buffer.zero_() |
| def reset(self) -> None: | ||
| if not self._kv_offload: | ||
| self._k_buffer.zero_() | ||
| self._v_buffer.zero_() | ||
| self._global_end.zero_() | ||
| self._local_end.zero_() | ||
| self._init_ring_metadata() | ||
| return | ||
| self.sync_all() | ||
| self._k_cpu.zero_() | ||
| self._v_cpu.zero_() | ||
| self._k_gpu_buf.zero_() | ||
| self._v_gpu_buf.zero_() | ||
| self._global_end.zero_() | ||
| self._local_end.zero_() | ||
| self._init_ring_metadata() |
There was a problem hiding this comment.
Removing zero_() in reset() and switching to torch.empty introduces two major risks:
- NaN Propagation / Correctness Issues:
torch.emptyallocates uninitialized memory which can contain arbitrary garbage, includingNaNorInfvalues. If the attention mechanism or any masking operation performs computations over the entire buffer (even if masked out by multiplying with 0, since0 * NaN = NaN), theseNaNs can propagate and corrupt the model's outputs. - Data Leakage (Security Vulnerability - CWE-226): Since the
KVCacheManagernow reuses cached KV buffers across different inference requests (by matchingbuffer_sig), not zeroing out the buffers onreset()means that sensitive data from a previous request remains in memory. If a subsequent request (potentially from a different user) reads beyond its current sequence length due to a bug, speculative decoding, or different sequence lengths, it could leak information from the previous request.
Recommendation: Revert to using torch.zeros or ensure that reset() explicitly zeroes out the buffers (e.g., using zero_()) to maintain correctness and security.
| def reset(self) -> None: | |
| if not self._kv_offload: | |
| self._k_buffer.zero_() | |
| self._v_buffer.zero_() | |
| self._global_end.zero_() | |
| self._local_end.zero_() | |
| self._init_ring_metadata() | |
| return | |
| self.sync_all() | |
| self._k_cpu.zero_() | |
| self._v_cpu.zero_() | |
| self._k_gpu_buf.zero_() | |
| self._v_gpu_buf.zero_() | |
| self._global_end.zero_() | |
| self._local_end.zero_() | |
| self._init_ring_metadata() | |
| def reset(self) -> None: | |
| if not self._kv_offload: | |
| self._k_buffer.zero_() | |
| self._v_buffer.zero_() | |
| self._global_end.zero_() | |
| self._local_end.zero_() | |
| self._init_ring_metadata() | |
| return | |
| self.sync_all() | |
| self._k_cpu.zero_() | |
| self._v_cpu.zero_() | |
| self._k_gpu_buf.zero_() | |
| self._v_gpu_buf.zero_() | |
| self._global_end.zero_() | |
| self._local_end.zero_() | |
| self._init_ring_metadata() |
| self._k_cpu = torch.empty(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() | ||
| self._v_cpu = torch.empty(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() | ||
| self._k_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device) | ||
| self._v_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device) |
There was a problem hiding this comment.
Using torch.empty instead of torch.zeros can lead to silent correctness issues and NaN propagation. Uninitialized memory allocated by torch.empty can contain arbitrary garbage, including NaN or Inf values. If the attention mechanism or any masking operation performs computations over the entire buffer (even if masked out by multiplying with 0, since 0 * NaN = NaN), these NaNs can propagate and corrupt the model's outputs.
Recommendation: Revert to torch.zeros to ensure that the buffers are safely initialized.
| self._k_cpu = torch.empty(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() | |
| self._v_cpu = torch.empty(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() | |
| self._k_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device) | |
| self._v_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device) | |
| self._k_cpu = torch.zeros(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() | |
| self._v_cpu = torch.zeros(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() | |
| self._k_gpu_buf = torch.zeros(N, H, D, dtype=self._dtype, device=self._device) | |
| self._v_gpu_buf = torch.zeros(N, H, D, dtype=self._dtype, device=self._device) |
| head_dim = self.config["dim"] // self.config["num_heads"] | ||
| buffer_sig = ( | ||
| int(self.kv_size), | ||
| int(self.cache_num_heads), | ||
| int(self.config["num_layers"]), | ||
| int(head_dim), | ||
| str(self.dtype), | ||
| int(self.frame_seq_length), | ||
| bool(self.ar_config.get("kv_quant")), | ||
| bool(self.ar_config.get("kv_offload")), | ||
| self.config.get("infer_steps"), | ||
| ) |
There was a problem hiding this comment.
Using bool(self.ar_config.get("kv_quant")) in buffer_sig is highly risky and can lead to crashes or silent data corruption.
If kv_quant is a dictionary (e.g., {"quant_scheme": "kivi"}), bool(kv_quant) evaluates to True. If a user changes the quantization settings between runs (for example, switching quant_scheme from "kivi" to "longlive_fp4", or changing the number of bits/group size), bool(kv_quant) remains True. As a result, buffer_sig will match, and the manager will reuse the existing self_attn_kv_cache instance of the wrong class or configuration, leading to AttributeError or severe runtime corruption when incompatible quantization kernels are executed.
Recommendation: Include a hashable representation of the kv_quant configuration in buffer_sig instead of just its boolean value.
head_dim = self.config["dim"] // self.config["num_heads"]
kv_quant = self.ar_config.get("kv_quant")
kv_quant_sig = tuple(sorted((k, str(v)) for k, v in kv_quant.items())) if isinstance(kv_quant, dict) else bool(kv_quant)
buffer_sig = (
int(self.kv_size),
int(self.cache_num_heads),
int(self.config["num_layers"]),
int(head_dim),
str(self.dtype),
int(self.frame_seq_length),
kv_quant_sig,
bool(self.ar_config.get("kv_offload")),
self.config.get("infer_steps"),
)| if torch.distributed.is_available() and torch.distributed.is_initialized(): | ||
| sp_group = getattr(self.model, "seq_p_group", None) | ||
| if sp_group is not None or torch.distributed.get_world_size() > 1: | ||
| torch.distributed.barrier(group=sp_group) |
There was a problem hiding this comment.
Calling torch.distributed.barrier on the default process group (when sp_group is None but world_size > 1) can cause deadlocks in distributed environments that use other parallelisms (such as pipeline parallel) where not all ranks execute this runner. Since the barrier is only needed to synchronize sequence-parallel ranks when the KV cache is sharded, it should only be called when sp_group is explicitly not None.
Recommendation: Simplify the condition to only synchronize when sp_group is present.
| if torch.distributed.is_available() and torch.distributed.is_initialized(): | |
| sp_group = getattr(self.model, "seq_p_group", None) | |
| if sp_group is not None or torch.distributed.get_world_size() > 1: | |
| torch.distributed.barrier(group=sp_group) | |
| if torch.distributed.is_available() and torch.distributed.is_initialized(): | |
| sp_group = getattr(self.model, "seq_p_group", None) | |
| if sp_group is not None: | |
| torch.distributed.barrier(group=sp_group) |
| if torch.distributed.is_available() and torch.distributed.is_initialized(): | ||
| sp_group = getattr(self.model, "seq_p_group", None) | ||
| if sp_group is not None or torch.distributed.get_world_size() > 1: | ||
| torch.distributed.barrier(group=sp_group) |
There was a problem hiding this comment.
Calling torch.distributed.barrier on the default process group (when sp_group is None but world_size > 1) can cause deadlocks in distributed environments that use other parallelisms (such as pipeline parallel) where not all ranks execute this runner. Since the barrier is only needed to synchronize sequence-parallel ranks when the KV cache is sharded, it should only be called when sp_group is explicitly not None.
Recommendation: Simplify the condition to only synchronize when sp_group is present.
| if torch.distributed.is_available() and torch.distributed.is_initialized(): | |
| sp_group = getattr(self.model, "seq_p_group", None) | |
| if sp_group is not None or torch.distributed.get_world_size() > 1: | |
| torch.distributed.barrier(group=sp_group) | |
| if torch.distributed.is_available() and torch.distributed.is_initialized(): | |
| sp_group = getattr(self.model, "seq_p_group", None) | |
| if sp_group is not None: | |
| torch.distributed.barrier(group=sp_group) |
| if torch.distributed.is_available() and torch.distributed.is_initialized(): | ||
| sp_group = getattr(self.model, "seq_p_group", None) | ||
| if sp_group is not None or torch.distributed.get_world_size() > 1: | ||
| torch.distributed.barrier(group=sp_group) |
There was a problem hiding this comment.
Calling torch.distributed.barrier on the default process group (when sp_group is None but world_size > 1) can cause deadlocks in distributed environments that use other parallelisms (such as pipeline parallel) where not all ranks execute this runner. Since the barrier is only needed to synchronize sequence-parallel ranks when the KV cache is sharded, it should only be called when sp_group is explicitly not None.
Recommendation: Simplify the condition to only synchronize when sp_group is present.
| if torch.distributed.is_available() and torch.distributed.is_initialized(): | |
| sp_group = getattr(self.model, "seq_p_group", None) | |
| if sp_group is not None or torch.distributed.get_world_size() > 1: | |
| torch.distributed.barrier(group=sp_group) | |
| if torch.distributed.is_available() and torch.distributed.is_initialized(): | |
| sp_group = getattr(self.model, "seq_p_group", None) | |
| if sp_group is not None: | |
| torch.distributed.barrier(group=sp_group) |
No description provided.