Skip to content

Optimize Seko AR KV cache#1109

Merged
gushiqiao merged 1 commit into
mainfrom
gsq/dev-seko-ar-1
Jun 1, 2026
Merged

Optimize Seko AR KV cache#1109
gushiqiao merged 1 commit into
mainfrom
gsq/dev-seko-ar-1

Conversation

@gushiqiao
Copy link
Copy Markdown
Contributor

No description provided.

@gushiqiao gushiqiao merged commit 741d9b5 into main Jun 1, 2026
2 checks passed
@gushiqiao gushiqiao deleted the gsq/dev-seko-ar-1 branch June 1, 2026 08:23
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces KV cache buffer reuse to avoid reallocation, adds a fused Triton kernel for KIVI dequantization, and refactors several model runners to reuse the cache manager. Feedback on these changes highlights critical issues: replacing torch.zeros with torch.empty and removing zero_() in reset() introduces risks of NaN propagation and sensitive data leakage (CWE-226) across sessions. Additionally, using a simple boolean check for kv_quant in the buffer signature can cause incorrect cache reuse when quantization configurations change, and calling torch.distributed.barrier without an explicit sequence parallel group can lead to deadlocks in distributed environments.

Comment on lines 50 to +51
def reset(self) -> None:
self._k_buffer.zero_()
self._v_buffer.zero_()
pass
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-critical critical

Removing zero_() in reset() and switching to torch.empty introduces two major risks:

  1. NaN Propagation / Correctness Issues: torch.empty allocates uninitialized memory which can contain arbitrary garbage, including NaN or Inf values. If the attention mechanism or any masking operation performs computations over the entire buffer (even if masked out by multiplying with 0, since 0 * NaN = NaN), these NaNs can propagate and corrupt the model's outputs.
  2. Data Leakage (Security Vulnerability - CWE-226): Since the KVCacheManager now reuses cached KV buffers across different inference requests (by matching buffer_sig), not zeroing out the buffers on reset() means that sensitive data from a previous request remains in memory. If a subsequent request (potentially from a different user) reads beyond its current sequence length due to a bug, speculative decoding, or different sequence lengths, it could leak information from the previous request.

Recommendation: Revert to using torch.zeros or ensure that reset() explicitly zeroes out the buffers (e.g., using zero_()) to maintain correctness and security.

Suggested change
def reset(self) -> None:
self._k_buffer.zero_()
self._v_buffer.zero_()
pass
def reset(self) -> None:
self._k_buffer.zero_()
self._v_buffer.zero_()

Comment on lines 383 to 392
def reset(self) -> None:
if not self._kv_offload:
self._k_buffer.zero_()
self._v_buffer.zero_()
self._global_end.zero_()
self._local_end.zero_()
self._init_ring_metadata()
return
self.sync_all()
self._k_cpu.zero_()
self._v_cpu.zero_()
self._k_gpu_buf.zero_()
self._v_gpu_buf.zero_()
self._global_end.zero_()
self._local_end.zero_()
self._init_ring_metadata()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-critical critical

Removing zero_() in reset() and switching to torch.empty introduces two major risks:

  1. NaN Propagation / Correctness Issues: torch.empty allocates uninitialized memory which can contain arbitrary garbage, including NaN or Inf values. If the attention mechanism or any masking operation performs computations over the entire buffer (even if masked out by multiplying with 0, since 0 * NaN = NaN), these NaNs can propagate and corrupt the model's outputs.
  2. Data Leakage (Security Vulnerability - CWE-226): Since the KVCacheManager now reuses cached KV buffers across different inference requests (by matching buffer_sig), not zeroing out the buffers on reset() means that sensitive data from a previous request remains in memory. If a subsequent request (potentially from a different user) reads beyond its current sequence length due to a bug, speculative decoding, or different sequence lengths, it could leak information from the previous request.

Recommendation: Revert to using torch.zeros or ensure that reset() explicitly zeroes out the buffers (e.g., using zero_()) to maintain correctness and security.

Suggested change
def reset(self) -> None:
if not self._kv_offload:
self._k_buffer.zero_()
self._v_buffer.zero_()
self._global_end.zero_()
self._local_end.zero_()
self._init_ring_metadata()
return
self.sync_all()
self._k_cpu.zero_()
self._v_cpu.zero_()
self._k_gpu_buf.zero_()
self._v_gpu_buf.zero_()
self._global_end.zero_()
self._local_end.zero_()
self._init_ring_metadata()
def reset(self) -> None:
if not self._kv_offload:
self._k_buffer.zero_()
self._v_buffer.zero_()
self._global_end.zero_()
self._local_end.zero_()
self._init_ring_metadata()
return
self.sync_all()
self._k_cpu.zero_()
self._v_cpu.zero_()
self._k_gpu_buf.zero_()
self._v_gpu_buf.zero_()
self._global_end.zero_()
self._local_end.zero_()
self._init_ring_metadata()

Comment on lines +166 to +169
self._k_cpu = torch.empty(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory()
self._v_cpu = torch.empty(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory()
self._k_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device)
self._v_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using torch.empty instead of torch.zeros can lead to silent correctness issues and NaN propagation. Uninitialized memory allocated by torch.empty can contain arbitrary garbage, including NaN or Inf values. If the attention mechanism or any masking operation performs computations over the entire buffer (even if masked out by multiplying with 0, since 0 * NaN = NaN), these NaNs can propagate and corrupt the model's outputs.

Recommendation: Revert to torch.zeros to ensure that the buffers are safely initialized.

Suggested change
self._k_cpu = torch.empty(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory()
self._v_cpu = torch.empty(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory()
self._k_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device)
self._v_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device)
self._k_cpu = torch.zeros(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory()
self._v_cpu = torch.zeros(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory()
self._k_gpu_buf = torch.zeros(N, H, D, dtype=self._dtype, device=self._device)
self._v_gpu_buf = torch.zeros(N, H, D, dtype=self._dtype, device=self._device)

Comment on lines +268 to +279
head_dim = self.config["dim"] // self.config["num_heads"]
buffer_sig = (
int(self.kv_size),
int(self.cache_num_heads),
int(self.config["num_layers"]),
int(head_dim),
str(self.dtype),
int(self.frame_seq_length),
bool(self.ar_config.get("kv_quant")),
bool(self.ar_config.get("kv_offload")),
self.config.get("infer_steps"),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using bool(self.ar_config.get("kv_quant")) in buffer_sig is highly risky and can lead to crashes or silent data corruption.

If kv_quant is a dictionary (e.g., {"quant_scheme": "kivi"}), bool(kv_quant) evaluates to True. If a user changes the quantization settings between runs (for example, switching quant_scheme from "kivi" to "longlive_fp4", or changing the number of bits/group size), bool(kv_quant) remains True. As a result, buffer_sig will match, and the manager will reuse the existing self_attn_kv_cache instance of the wrong class or configuration, leading to AttributeError or severe runtime corruption when incompatible quantization kernels are executed.

Recommendation: Include a hashable representation of the kv_quant configuration in buffer_sig instead of just its boolean value.

        head_dim = self.config["dim"] // self.config["num_heads"]
        kv_quant = self.ar_config.get("kv_quant")
        kv_quant_sig = tuple(sorted((k, str(v)) for k, v in kv_quant.items())) if isinstance(kv_quant, dict) else bool(kv_quant)
        buffer_sig = (
            int(self.kv_size),
            int(self.cache_num_heads),
            int(self.config["num_layers"]),
            int(head_dim),
            str(self.dtype),
            int(self.frame_seq_length),
            kv_quant_sig,
            bool(self.ar_config.get("kv_offload")),
            self.config.get("infer_steps"),
        )

Comment on lines +1191 to +1194
if torch.distributed.is_available() and torch.distributed.is_initialized():
sp_group = getattr(self.model, "seq_p_group", None)
if sp_group is not None or torch.distributed.get_world_size() > 1:
torch.distributed.barrier(group=sp_group)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling torch.distributed.barrier on the default process group (when sp_group is None but world_size > 1) can cause deadlocks in distributed environments that use other parallelisms (such as pipeline parallel) where not all ranks execute this runner. Since the barrier is only needed to synchronize sequence-parallel ranks when the KV cache is sharded, it should only be called when sp_group is explicitly not None.

Recommendation: Simplify the condition to only synchronize when sp_group is present.

Suggested change
if torch.distributed.is_available() and torch.distributed.is_initialized():
sp_group = getattr(self.model, "seq_p_group", None)
if sp_group is not None or torch.distributed.get_world_size() > 1:
torch.distributed.barrier(group=sp_group)
if torch.distributed.is_available() and torch.distributed.is_initialized():
sp_group = getattr(self.model, "seq_p_group", None)
if sp_group is not None:
torch.distributed.barrier(group=sp_group)

Comment on lines +70 to +73
if torch.distributed.is_available() and torch.distributed.is_initialized():
sp_group = getattr(self.model, "seq_p_group", None)
if sp_group is not None or torch.distributed.get_world_size() > 1:
torch.distributed.barrier(group=sp_group)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling torch.distributed.barrier on the default process group (when sp_group is None but world_size > 1) can cause deadlocks in distributed environments that use other parallelisms (such as pipeline parallel) where not all ranks execute this runner. Since the barrier is only needed to synchronize sequence-parallel ranks when the KV cache is sharded, it should only be called when sp_group is explicitly not None.

Recommendation: Simplify the condition to only synchronize when sp_group is present.

Suggested change
if torch.distributed.is_available() and torch.distributed.is_initialized():
sp_group = getattr(self.model, "seq_p_group", None)
if sp_group is not None or torch.distributed.get_world_size() > 1:
torch.distributed.barrier(group=sp_group)
if torch.distributed.is_available() and torch.distributed.is_initialized():
sp_group = getattr(self.model, "seq_p_group", None)
if sp_group is not None:
torch.distributed.barrier(group=sp_group)

Comment on lines +55 to +58
if torch.distributed.is_available() and torch.distributed.is_initialized():
sp_group = getattr(self.model, "seq_p_group", None)
if sp_group is not None or torch.distributed.get_world_size() > 1:
torch.distributed.barrier(group=sp_group)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling torch.distributed.barrier on the default process group (when sp_group is None but world_size > 1) can cause deadlocks in distributed environments that use other parallelisms (such as pipeline parallel) where not all ranks execute this runner. Since the barrier is only needed to synchronize sequence-parallel ranks when the KV cache is sharded, it should only be called when sp_group is explicitly not None.

Recommendation: Simplify the condition to only synchronize when sp_group is present.

Suggested change
if torch.distributed.is_available() and torch.distributed.is_initialized():
sp_group = getattr(self.model, "seq_p_group", None)
if sp_group is not None or torch.distributed.get_world_size() > 1:
torch.distributed.barrier(group=sp_group)
if torch.distributed.is_available() and torch.distributed.is_initialized():
sp_group = getattr(self.model, "seq_p_group", None)
if sp_group is not None:
torch.distributed.barrier(group=sp_group)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants