
Add MLA support to vLLM patch #1499

Open
rohansjoshi wants to merge 6 commits into main from rohjoshi/sparse-kernel

Conversation

Contributor

@rohansjoshi rohansjoshi commented May 15, 2026

What does this PR do?

Builds on the vLLM integration PR:

  • Adds support for threshold scale factors (instead of fixed thresholds) in the skip-softmax Triton kernel
  • Adds MLA support to the vLLM patch

Type of change: new feature

Usage

SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python examples/vllm_serve/vllm_serve_sparse_attn.py \
    nvidia/DeepSeek-R1-NVFP4
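
For orientation, a minimal sketch of how a worker could turn SPARSE_ATTN_CFG into a config. Per the walkthrough below, the real helpers in examples/vllm_serve/sparse_attn_worker.py build the sparse config from a preset name or a calibration JSON; the parsing shown here is an assumption for illustration, not the actual implementation.

import json
import os

# Hypothetical sketch only; the real logic lives in sparse_attn_worker.py.
cfg_value = os.environ.get("SPARSE_ATTN_CFG", "")
if cfg_value.endswith(".json"):
    # Offline calibration config with per-layer settings.
    with open(cfg_value, encoding="utf-8") as f:
        sparse_cfg = json.load(f)
else:
    # Preset name, e.g. SPARSE_SOFTMAX_DEFAULT as in the command above.
    sparse_cfg = {"preset": cfg_value}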

Summary by CodeRabbit

  • New Features

    • Added vLLM server integration with sparse attention support, including example scripts and worker implementations for efficient inference.
    • Introduced paged KV-cache support for Triton flash-attention kernels.
    • Added skip-softmax scale-factor threshold mode for fine-grained sparsity control.
  • Dependencies

    • Added uvloop and vllm==0.20.1 dependencies.
  • Tests

    • Added comprehensive test coverage for paged KV-cache functionality and skip-softmax scale-factor behavior.

Review Change Stack

kaix-nv and others added 5 commits May 15, 2026 18:05
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
@rohansjoshi rohansjoshi requested review from a team as code owners May 15, 2026 18:19
Contributor

coderabbitai Bot commented May 15, 2026

📝 Walkthrough

This pull request adds comprehensive sparse-attention support for vLLM by extending Triton flash-attention kernels with paged KV-cache and per-sequence skip-softmax thresholds, implementing a vLLM sparse-attention plugin, providing worker classes for model patching, and delivering a server launcher with extensive test coverage and configuration updates.

Changes

vLLM Sparse Attention Infrastructure

Layer / File(s) Summary
Triton flash-attention paged KV-cache support
modelopt/torch/kernels/common/attention/triton_fa.py
Adds _load_paged_k_tile and _load_paged_v_tile Triton helpers to load K/V tiles from paged cache via block tables, extends forward kernel signature with sm_scale and paged-KV constexpr parameters, and updates K/V tiling loops to conditionally load from paged cache.
Skip-softmax scale-factor mode
modelopt/torch/kernels/common/attention/triton_fa.py, modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
Adds USE_SKIP_SCALE_FACTOR and SKIP_THRESHOLD_SCALE_LOG2 constexpr controls to forward/backward kernels; derives per-program skip_threshold_log2 inside kernels by subtracting log2(seq_len_kv) * sm_scale from scale-factor parameter or using fixed threshold. Updates _skip_softmax_decision to accept runtime skip_threshold_log2 instead of compile-time constant.
Autograd wrapper and public API
modelopt/torch/kernels/common/attention/triton_fa.py
Updates _Attention.forward to accept scale-factor and paged-KV parameters, determines is_paged, derives mode-dependent values, propagates to kernel launches, and adjusts backward return tuple. Changes attention(...) signature to remove skip_softmax_raw_threshold and measure_sparsity, add scale-factor thresholds and paged-KV arguments. Adds exported attention_with_lse(...) helper returning (output, lse) in natural-log space.
vLLM sparse attention plugin
modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
Implements ModelOptSparseAttentionImpl to override forward() and invoke Triton sparse kernel with paged KV metadata and sparse_kw parameters, logs one-time-per-config warnings. Adds ModelOptSparseAttentionBackend registering backend name MODELOPT_SPARSE. Provides MLA monkey-patch helpers (_modelopt_mla_run_prefill_new_tokens, _modelopt_mla_run_prefill_context_chunk) calling attention_with_lse and padding V when needed.
vLLM sparse worker classes
examples/vllm_serve/sparse_attn_worker.py
Defines SparseAttnWorker and SparseQuantWorker extending BaseWorker; both invoke _replace_attention_impl during load_model to swap Attention layers with sparse implementations or monkey-patch MLA. SparseQuantWorker additionally runs quantization prolog during compile_or_warm_up_model. Helper functions build sparse config from environment variables (preset or calibration JSON), match module names via fnmatch, and construct per-layer sparse-kernel kwargs with optional skip_softmax_threshold override.
vLLM server launcher
examples/vllm_serve/vllm_serve_sparse_attn.py
Detects vLLM version to import correct FlexibleArgumentParser and RayDistributedExecutor, builds sparse/quantization environment-variable list, propagates to Ray workers when available. main() configures CLI parsing, amends sys.path/PYTHONPATH for local worker import, selects worker class based on SPARSE_* and QUANT_* env-vars (with fallback warning), and launches server via uvloop.run(run_server(args)).
Paged KV-cache tests
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py
Adds _scatter_to_paged_cache helper and TestPagedKV test class validating paged output matches contiguous reference, contains no NaNs/Infs, handles variable-length sequences, works across multiple page sizes, functions with N:M sparsity, and supports decode scenarios.
Skip-softmax scale-factor tests
tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
Adds TestSkipSoftmaxScaleFactor class validating: disabling scale factors reproduces dense output, scale-factor mode equals fixed-threshold under uniform sequence length, decode-phase uses only decode scale-factor, prefill-phase uses only prefill scale-factor, fixed threshold and scale-factor mode are mutually exclusive, and scale-factor path matches PyTorch reference.
Configuration and dependencies
modelopt/torch/sparsity/attention_sparsity/config.py, pyproject.toml, modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py, modelopt/torch/quantization/plugins/vllm.py
Updates skip-softmax thresholds from 0.1 to 0.001 in SKIP_SOFTMAX_TRITON_DEFAULT and SPARSE_SOFTMAX_SKIP_DEFAULT; adds uvloop>=0.22.1 and vllm==0.20.1 dependencies; updates Ruff isort to classify vllm as third-party; adds import_plugin import; refactors vLLM quantization plugin to use new _vllm_module_spec_exists helper for safer module probing.
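
To make the skip-softmax scale-factor entry above concrete, here is the per-sequence derivation it describes, restated as plain Python rather than Triton; the function and argument names are illustrative and do not match the kernel's symbols exactly.

import math

def per_sequence_threshold_log2(
    skip_threshold_scale_log2: float,  # host-side: log2(scale_factor) * sm_scale
    skip_threshold_log2: float,        # host-side: log2(fixed_threshold) * sm_scale
    seq_len_kv: int,
    sm_scale: float,
    use_skip_scale_factor: bool,
) -> float:
    # Scale-factor mode: subtract log2(seq_len_kv) * sm_scale, so the effective
    # threshold shrinks as the KV sequence grows (roughly scale / seq_len_kv).
    if use_skip_scale_factor:
        return skip_threshold_scale_log2 - math.log2(seq_len_kv) * sm_scale
    # Fixed mode: the pre-scaled log2(lambda) is used as-is.
    return skip_threshold_log2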

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'Add MLA support to vLLM patch' accurately describes the primary feature addition in the changeset: MLA (multi-head latent attention) support for vLLM sparse attention.
Docstring Coverage ✅ Passed Docstring coverage is 87.80% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns found in PR. Complies with SECURITY.md: no unsafe deserialization, hardcoded trust_remote_code, eval/exec risks, or nosec bypasses. New dependencies have permissive licenses.




@rohansjoshi rohansjoshi requested review from kaix-nv and removed request for realAsma May 15, 2026 18:20
@rohansjoshi rohansjoshi changed the title Rohjoshi/sparse kernel Add MLA support to vLLM patch May 15, 2026
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 14

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/kernels/common/attention/triton_fa.py (2)

894-953: ⚠️ Potential issue | 🔴 Critical | 🏗️ Heavy lift

Resolve merge conflict in skip-softmax threshold computation.

Two different threshold configuration models:

  • HEAD: skip_softmax_raw_threshold (raw kernel value) vs skip_softmax_threshold (lambda)
  • Incoming: mutually exclusive fixed mode vs scale-factor mode with prefill/decode phase detection

Decide on the threshold configuration API and resolve accordingly.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/kernels/common/attention/triton_fa.py` around lines 894 - 953,
Resolve the merge conflict by unifying both configurations: support the raw
kernel override (skip_softmax_raw_threshold) while keeping the
mutually-exclusive fixed vs scale-factor API (skip_softmax_threshold and
skip_softmax_threshold_scale_prefill/decode). Implement logic that first checks
skip_softmax_raw_threshold (if not None) and sets apply_skip=True and
skip_threshold_log2=skip_softmax_raw_threshold and
skip_threshold_scale_log2=0.0; otherwise determine fixed_mode,
scale_prefill_set, scale_decode_set and assert mutual exclusivity, compute
is_decode from max_input_len, set active_scale and use_scale_factor, then set
apply_skip, skip_threshold_log2 (math.log2(skip_softmax_threshold)*sm_scale for
fixed_mode or 0.0 for scale-factor), and skip_threshold_scale_log2 (0.0 for
fixed_mode or math.log2(active_scale)*sm_scale for scale-factor), with the
fallback of apply_skip=False and both skip_threshold_* set to 0.0.
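
A compact sketch of the resolution order this prompt proposes: raw override first, then the mutually exclusive fixed vs. scale-factor modes with prefill/decode selection. This is an illustration of the suggested logic under the assumptions stated in the prompt, not the merged code.

import math

def resolve_skip_config(
    sm_scale,
    max_input_len,
    skip_softmax_raw_threshold=None,
    skip_softmax_threshold=None,
    skip_softmax_threshold_scale_prefill=None,
    skip_softmax_threshold_scale_decode=None,
):
    """Returns (apply_skip, skip_threshold_log2, skip_threshold_scale_log2, use_scale_factor)."""
    if skip_softmax_raw_threshold is not None:
        # Raw kernel override wins outright.
        return True, skip_softmax_raw_threshold, 0.0, False
    fixed_mode = skip_softmax_threshold is not None
    scale_set = (
        skip_softmax_threshold_scale_prefill is not None
        or skip_softmax_threshold_scale_decode is not None
    )
    assert not (fixed_mode and scale_set), "fixed and scale-factor modes are mutually exclusive"
    if fixed_mode:
        return True, math.log2(skip_softmax_threshold) * sm_scale, 0.0, False
    if scale_set:
        # Decode phase detected from max_input_len (assumed: single query token).
        is_decode = max_input_len == 1
        active_scale = (
            skip_softmax_threshold_scale_decode
            if is_decode
            else skip_softmax_threshold_scale_prefill
        )
        if active_scale is not None:
            return True, 0.0, math.log2(active_scale) * sm_scale, True
    # Fallback: skipping disabled.
    return False, 0.0, 0.0, False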

370-406: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Resolve merge conflict in skip-softmax implementation.

Two different implementations of skip-softmax tile-skipping logic are present:

  • HEAD: delegates to _skip_softmax_decision helper with sparsity counters
  • Incoming: inline BLASST-style tile skipping with scale-factor support

Choose one implementation or merge both features if both sparsity measurement and scale-factor mode are needed.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/kernels/common/attention/triton_fa.py` around lines 370 - 406,
There is a merge conflict between two skip-softmax implementations: one that
calls the helper _skip_softmax_decision(scores, row_max, SKIP_THRESHOLD_LOG2,
Sparsity_total, Sparsity_skipped, MEASURE_SPARSITY) and an inline BLASST-style
implementation using tile_row_max, can_skip, skip_threshold_log2 and
USE_SKIP_SCALE_FACTOR; pick one coherent approach or merge them by keeping the
BLASST inline logic (tile_row_max, can_skip, skip_tile) and then update the
sparsity counters (Sparsity_total, Sparsity_skipped, MEASURE_SPARSITY) in the
same block so metrics remain collected, or replace the inline block with a
single call to _skip_softmax_decision and ensure the threshold symbol is
consistent (SKIP_THRESHOLD_LOG2 vs skip_threshold_log2) and
USE_SKIP_SCALE_FACTOR behavior is respected; update only the code paths
referenced by _skip_softmax_decision or the inline variables (scores, row_max,
tile_row_max, skip_tile, SKIP_THRESHOLD_LOG2/skip_threshold_log2,
Sparsity_total, Sparsity_skipped, MEASURE_SPARSITY, USE_SKIP_SCALE_FACTOR) so
the final file contains one unambiguous implementation and consistent variable
names.
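
For orientation, the tile-skip test that the BLASST-style inline variant is described as performing can be sketched outside Triton as follows; the rule is inferred from the variable names above (tile_row_max, skip_threshold_log2) and the usual skip-softmax criterion, so treat it as an assumption rather than the kernel's exact code.

# Assumed semantics, simplified to one query row: scores are in the log2 domain,
# m_i is the running row maximum maintained by the online softmax.
def can_skip_tile(tile_scores, m_i, skip_threshold_log2):
    tile_row_max = max(tile_scores)
    # exp2(tile_row_max - m_i) < threshold  <=>  tile_row_max < m_i + log2(threshold)
    return tile_row_max < m_i + skip_threshold_log2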
🧹 Nitpick comments (2)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py (1)

290-337: ⚡ Quick win

Strengthen decode test with correctness parity (not only finiteness).

Current decode coverage checks shape/NaN, but not correctness against a contiguous reference, so silent numerical regressions can slip through. Add a contiguous decode baseline and assert closeness; also add an Inf check for consistency with other tests.

Suggested patch
@@
     def test_paged_decode(self):
@@
-        out = attention(
+        out_contig = attention(
+            q_flat,
+            k_flat,
+            v_flat,
+            b_start_loc_q,
+            b_seq_len_q,
+            1,
+            is_causal=False,
+            softmax_scale=scale,
+            b_start_loc_k=b_start_loc_k,
+            b_seq_len_k=b_seq_len_k,
+            max_input_len_k=max(seq_lens_k),
+        )
+
+        out = attention(
             q_flat,
             k_flat,
             v_flat,
@@
         )
 
+        torch.testing.assert_close(out, out_contig, rtol=1e-2, atol=1e-2)
         assert out.shape == q_flat.shape
         assert not torch.isnan(out).any(), "NaN in paged decode output"
+        assert not torch.isinf(out).any(), "Inf in paged decode output"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py` around
lines 290 - 337, test_paged_decode only checks shape/NaN but not numeric parity
with a contiguous decode; call the same attention(...) without the paged cache
(pass k_flat/v_flat and omit k_cache/v_cache/block_table/page_size, using the
same q_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k and
softmax_scale) to produce a baseline output, then assert the paged output from
attention(...) is close to the baseline (use torch.allclose or torch.isclose
with reasonable rtol/atol) and also assert torch.isfinite on both outputs;
update the test_paged_decode function (references: test_paged_decode, attention,
_scatter_to_paged_cache, q_flat, k_flat, v_flat, k_cache, v_cache, block_table)
to include these comparisons.
modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py (1)

107-115: 💤 Low value

Consider consolidating logging output.

The message is emitted via both logger.warning() and print(). Since the logger should handle output routing, the explicit print() may be redundant unless there's a specific need to ensure console visibility independent of logging configuration.

♻️ Simplify by removing print
             logger.warning(msg)
-            print(msg, flush=True)
             _LOGGED_CONFIGS.add(kw_key)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py` around lines 107
- 115, The code logs the sparse Triton kernel message twice by calling
logger.warning(msg) and print(msg); remove the redundant print() call so all
output is routed through the configured logger (leave logger.warning(msg) and
the subsequent _LOGGED_CONFIGS.add(kw_key) intact), referencing the same
variables used in the message construction (variant_str, is_prefill, layer,
sparse_kw, kw_key) to locate the spot in vllm.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 148-152: The check that sets sparse_kw["skip_softmax_threshold"]
incorrectly treats 0.0 as falsy; change the conditional that currently uses "if
threshold:" to explicitly test for non-None (e.g., "if threshold is not None:")
so that a valid zero threshold from layer_cfg or env_threshold is preserved;
update the block that reads threshold from layer_cfg and env_threshold
(variables: threshold, layer_cfg, env_threshold) and assigns to sparse_kw to use
this non-None check.
- Around line 107-108: Replace the bare file open used to load calib_cfg with a
robust read that specifies encoding and handles errors: change the block using
"with open(path) as f: calib_cfg = json.load(f)" to "with open(path, 'r',
encoding='utf-8') as f: ..." and wrap it in a try/except that catches
FileNotFoundError, json.JSONDecodeError, and generic OSError to either log a
clear error (including the path and exception) or re-raise a descriptive
exception; ensure the variable name calib_cfg and the path variable remain the
same so callers are unaffected.

In `@examples/vllm_serve/vllm_serve_sparse_attn.py`:
- Around line 78-81: The PYTHONPATH concatenation currently always prepends ":"
which yields a leading empty path entry when PYTHONPATH is empty; update the
logic around repo_root/sys.path.insert and os.environ["PYTHONPATH"] to read the
existing value (os.environ.get("PYTHONPATH")), and set os.environ["PYTHONPATH"]
to either "<existing> + os.pathsep + repo_root" if existing is non-empty or just
"repo_root" if not, using os.pathsep instead of a hardcoded ":" to be portable
and avoid introducing an empty path entry.

In `@modelopt/torch/kernels/__init__.py`:
- Around line 16-55: Remove the leftover merge conflict markers and keep the
completed initialization logic from the incoming change: use the descriptive
module docstring, import torch and import_plugin, define IS_AVAILABLE=False and
attention/attention_with_lse/register_triton_attention=None, then if
torch.cuda.is_available() use import_plugin("triton") to import .triton_fa
attention and attention_with_lse, set IS_AVAILABLE=True, and inside an
import_plugin("transformers") block import and assign register_triton_attention
from .hf_triton_attention; finally ensure __all__ contains "IS_AVAILABLE",
"attention", "attention_with_lse", and "register_triton_attention".

In `@modelopt/torch/kernels/common/attention/triton_fa.py`:
- Around line 1323-1325: There is a merge conflict in the module exports:
reconcile the __all__ to export the correct public API by including the
functions consumers need—ensure LOG2E and _apply_mask are exported only if they
are intended public symbols, and add the new API attention_with_lse if the
incoming branch introduced it; update the __all__ to something like ["LOG2E",
"_apply_mask", "attention", "attention_with_lse"] (omit LOG2E/_apply_mask if
they are internal), and keep the symbol names exactly as defined in this file
(LOG2E, _apply_mask, attention, attention_with_lse) so imports work correctly.
- Around line 1309-1320: The call to _Attention.apply has leftover
merge-conflict argument names; update the argument list to match the resolved
forward signature by removing skip_softmax_raw_threshold and measure_sparsity
and passing skip_softmax_threshold_scale_prefill and
skip_softmax_threshold_scale_decode in their place (preserving the order used by
_Attention.apply), and ensure other arguments like k_cache, v_cache,
block_table, and page_size remain unchanged so the call matches the
implementation signature.
- Around line 1005-1012: The merge left conflicting argument names in the
_attn_fwd kernel call: remove the leftover
Sparsity_total/Sparsity_skipped/MEASURE_SPARSITY branch and replace it with the
resolved kernel parameters USE_SKIP_SCALE_FACTOR and SKIP_THRESHOLD_SCALE_LOG2
(passing the local variables use_scale_factor and skip_threshold_scale_log2
respectively) so the _attn_fwd invocation matches the final kernel signature;
ensure no duplicate or undefined symbols remain and that the argument
order/keywords align with the kernel's definition.
- Around line 1182-1193: The backward gradient tuple in triton_fa.py must match
the resolved forward parameter list: replace the two outdated placeholders
(skip_softmax_raw_threshold, measure_sparsity) with the correct entries for
skip_softmax_threshold_scale_prefill and skip_softmax_threshold_scale_decode and
ensure the ordering matches the forward signature before the existing k_cache,
v_cache, block_table, page_size placeholders so the number and order of None
entries align with the forward method.
- Around line 238-247: The _attn_fwd kernel signature contains git conflict
markers and missing parameters; remove the conflict markers and merge both
branches so the signature includes SKIP_THRESHOLD_LOG2: tl.constexpr,
USE_SKIP_SCALE_FACTOR: tl.constexpr, SKIP_THRESHOLD_SCALE_LOG2: tl.constexpr,
plus the runtime sparsity measurement params Sparsity_total, Sparsity_skipped
and MEASURE_SPARSITY (keeping the existing types/comments), and update the
inline comments so SKIP_THRESHOLD_LOG2 is documented as the fixed-mode
pre-scaled log2(lambda) and SKIP_THRESHOLD_SCALE_LOG2/USE_SKIP_SCALE_FACTOR are
documented for scale-factor mode; ensure no duplicate names and that the
parameter order matches callers of _attn_fwd.
- Around line 856-862: The _Attention.forward signature contains leftover
merge-conflict parameters; update the function signature to match the resolved
kernel signature (the parameters shown around lines 238-247 in the diff) by
removing the obsolete names (e.g., skip_softmax_raw_threshold, measure_sparsity)
and replacing them with the finalized parameters (e.g.,
skip_softmax_threshold_scale_prefill, skip_softmax_threshold_scale_decode) so
the forward method parameters exactly mirror the resolved triton kernel
signature.
- Around line 1214-1220: The public attention() signature contains unresolved
merge markers: remove the conflict markers and settle on a single API; choose
whether to keep skip_softmax_raw_threshold or the new pair
skip_softmax_threshold_scale_prefill and skip_softmax_threshold_scale_decode (do
not leave both variants). For minimal breakage prefer the new explicit pair plus
retain measure_sparsity if it was intended: update the attention(...) parameter
list to include skip_softmax_threshold_scale_prefill: float | None = None,
skip_softmax_threshold_scale_decode: float | None = None, measure_sparsity: bool
= False (and remove skip_softmax_raw_threshold and all <<<<>>> markers), then
update the attention() docstring and any exports/tests that reference the old
name to match the chosen signature.
- Around line 1409-1454: The _attn_fwd kernel call is missing the sm_scale
argument expected immediately after qk_scale; update the invocation of _attn_fwd
to pass the sm_scale parameter (the same value/type used elsewhere where the
kernel is called or from the local variable named sm_scale) between qk_scale and
b_start_loc so the argument order matches the kernel signature used by _attn_fwd
and avoids parameter misalignment.

In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 728-732: The __all__ export list contains unresolved merge markers
and is currently choosing between VSA_DEFAULT and SPARSE_SOFTMAX_SKIP_DEFAULT;
remove the conflict markers and export the correct names — include both
VSA_DEFAULT and SPARSE_SOFTMAX_SKIP_DEFAULT in the __all__ list (no merge
markers, no duplicates) so both symbols are publicly exported from the module.

In `@pyproject.toml`:
- Around line 52-53: Remove "uvloop>=0.22.1" from the core pyproject.toml
dependencies and add it to the examples/vllm_serve/requirements.txt since uvloop
is only used by the example scripts (refer to examples/vllm_serve/). For
"vllm==0.20.1", move it out of core dependencies into
[project.optional-dependencies] (or an extras group) and document that it is
required only for the vLLM quantization plugin invoked via
import_plugin("vllm"); also relax the pin to a range such as ~=0.20.1 or
>=0.20.1,<0.21 to avoid a fragile exact-match dependency.

---

Outside diff comments:
In `@modelopt/torch/kernels/common/attention/triton_fa.py`:
- Around line 894-953: Resolve the merge conflict by unifying both
configurations: support the raw kernel override (skip_softmax_raw_threshold)
while keeping the mutually-exclusive fixed vs scale-factor API
(skip_softmax_threshold and skip_softmax_threshold_scale_prefill/decode).
Implement logic that first checks skip_softmax_raw_threshold (if not None) and
sets apply_skip=True and skip_threshold_log2=skip_softmax_raw_threshold and
skip_threshold_scale_log2=0.0; otherwise determine fixed_mode,
scale_prefill_set, scale_decode_set and assert mutual exclusivity, compute
is_decode from max_input_len, set active_scale and use_scale_factor, then set
apply_skip, skip_threshold_log2 (math.log2(skip_softmax_threshold)*sm_scale for
fixed_mode or 0.0 for scale-factor), and skip_threshold_scale_log2 (0.0 for
fixed_mode or math.log2(active_scale)*sm_scale for scale-factor), with the
fallback of apply_skip=False and both skip_threshold_* set to 0.0.
- Around line 370-406: There is a merge conflict between two skip-softmax
implementations: one that calls the helper _skip_softmax_decision(scores,
row_max, SKIP_THRESHOLD_LOG2, Sparsity_total, Sparsity_skipped,
MEASURE_SPARSITY) and an inline BLASST-style implementation using tile_row_max,
can_skip, skip_threshold_log2 and USE_SKIP_SCALE_FACTOR; pick one coherent
approach or merge them by keeping the BLASST inline logic (tile_row_max,
can_skip, skip_tile) and then update the sparsity counters (Sparsity_total,
Sparsity_skipped, MEASURE_SPARSITY) in the same block so metrics remain
collected, or replace the inline block with a single call to
_skip_softmax_decision and ensure the threshold symbol is consistent
(SKIP_THRESHOLD_LOG2 vs skip_threshold_log2) and USE_SKIP_SCALE_FACTOR behavior
is respected; update only the code paths referenced by _skip_softmax_decision or
the inline variables (scores, row_max, tile_row_max, skip_tile,
SKIP_THRESHOLD_LOG2/skip_threshold_log2, Sparsity_total, Sparsity_skipped,
MEASURE_SPARSITY, USE_SKIP_SCALE_FACTOR) so the final file contains one
unambiguous implementation and consistent variable names.

---

Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py`:
- Around line 107-115: The code logs the sparse Triton kernel message twice by
calling logger.warning(msg) and print(msg); remove the redundant print() call so
all output is routed through the configured logger (leave logger.warning(msg)
and the subsequent _LOGGED_CONFIGS.add(kw_key) intact), referencing the same
variables used in the message construction (variant_str, is_prefill, layer,
sparse_kw, kw_key) to locate the spot in vllm.py.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py`:
- Around line 290-337: test_paged_decode only checks shape/NaN but not numeric
parity with a contiguous decode; call the same attention(...) without the paged
cache (pass k_flat/v_flat and omit k_cache/v_cache/block_table/page_size, using
the same q_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k and
softmax_scale) to produce a baseline output, then assert the paged output from
attention(...) is close to the baseline (use torch.allclose or torch.isclose
with reasonable rtol/atol) and also assert torch.isfinite on both outputs;
update the test_paged_decode function (references: test_paged_decode, attention,
_scatter_to_paged_cache, q_flat, k_flat, v_flat, k_cache, v_cache, block_table)
to include these comparisons.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: a1b4948d-004e-4822-9b51-e4f09740b68c

📥 Commits

Reviewing files that changed from the base of the PR and between a451a2b and a90c9d6.

📒 Files selected for processing (11)
  • examples/vllm_serve/sparse_attn_worker.py
  • examples/vllm_serve/vllm_serve_sparse_attn.py
  • modelopt/torch/kernels/__init__.py
  • modelopt/torch/kernels/common/attention/triton_fa.py
  • modelopt/torch/quantization/plugins/vllm.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
  • pyproject.toml
  • tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py

Comment on lines +107 to +108
with open(path) as f:
    calib_cfg = json.load(f)
Contributor


⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add error handling and specify encoding for file operations.

The file open operation lacks error handling and doesn't specify an encoding, which can cause issues on systems with non-UTF-8 defaults.

🛡️ Proposed fix
 def _load_sparse_config(path: str) -> dict:
     """Load offline calibration config JSON."""
-    with open(path) as f:
-        calib_cfg = json.load(f)
+    try:
+        with open(path, encoding="utf-8") as f:
+            calib_cfg = json.load(f)
+    except (FileNotFoundError, json.JSONDecodeError) as e:
+        raise ValueError(f"Failed to load sparse config from {path}: {e}") from e
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 107 - 108, Replace
the bare file open used to load calib_cfg with a robust read that specifies
encoding and handles errors: change the block using "with open(path) as f:
calib_cfg = json.load(f)" to "with open(path, 'r', encoding='utf-8') as f: ..."
and wrap it in a try/except that catches FileNotFoundError,
json.JSONDecodeError, and generic OSError to either log a clear error (including
the path and exception) or re-raise a descriptive exception; ensure the variable
name calib_cfg and the path variable remain the same so callers are unaffected.

Comment on lines +148 to +152
threshold = layer_cfg.get("skip_softmax_threshold")
if env_threshold is not None:
    threshold = env_threshold
if threshold:
    sparse_kw["skip_softmax_threshold"] = float(threshold)
Contributor


⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix threshold validation to handle zero values.

The condition if threshold: evaluates to False when threshold=0.0, which would skip a valid zero threshold. Use if threshold is not None: instead.

🐛 Proposed fix
     threshold = layer_cfg.get("skip_softmax_threshold")
     if env_threshold is not None:
         threshold = env_threshold
-    if threshold:
+    if threshold is not None:
         sparse_kw["skip_softmax_threshold"] = float(threshold)
     return sparse_kw
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 148 - 152, The check
that sets sparse_kw["skip_softmax_threshold"] incorrectly treats 0.0 as falsy;
change the conditional that currently uses "if threshold:" to explicitly test
for non-None (e.g., "if threshold is not None:") so that a valid zero threshold
from layer_cfg or env_threshold is preserved; update the block that reads
threshold from layer_cfg and env_threshold (variables: threshold, layer_cfg,
env_threshold) and assigns to sparse_kw to use this non-None check.

Comment on lines +78 to +81
repo_root = str(Path(__file__).resolve().parent)
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}"
Contributor


⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix PYTHONPATH concatenation to avoid empty path entry.

When PYTHONPATH is empty or unset, the current code creates a leading : which adds an empty entry to the path, potentially causing Python to search the current directory. Use conditional logic to handle the empty case.

🐛 Proposed fix
     repo_root = str(Path(__file__).resolve().parent)
     if repo_root not in sys.path:
         sys.path.insert(0, repo_root)
-    os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}"
+    existing = os.environ.get("PYTHONPATH", "")
+    os.environ["PYTHONPATH"] = f"{repo_root}:{existing}" if existing else repo_root
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/vllm_serve/vllm_serve_sparse_attn.py` around lines 78 - 81, The
PYTHONPATH concatenation currently always prepends ":" which yields a leading
empty path entry when PYTHONPATH is empty; update the logic around
repo_root/sys.path.insert and os.environ["PYTHONPATH"] to read the existing
value (os.environ.get("PYTHONPATH")), and set os.environ["PYTHONPATH"] to either
"<existing> + os.pathsep + repo_root" if existing is non-empty or just
"repo_root" if not, using os.pathsep instead of a hardcoded ":" to be portable
and avoid introducing an empty path entry.

Comment thread modelopt/torch/kernels/__init__.py Outdated
Comment on lines +16 to +55
<<<<<<< HEAD
"""ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm)."""
=======
"""Shared Triton kernels for modelopt (attention, quantization, etc.)."""

import torch

from modelopt.torch.utils import import_plugin

IS_AVAILABLE = False
attention = None
attention_with_lse = None
register_triton_attention = None

if torch.cuda.is_available():
with import_plugin(
"triton",
msg_if_missing=(
"Your device is potentially capable of using the triton attention "
"kernel. Try to install triton with `pip install triton`."
),
):
from .triton_fa import attention as _attention
from .triton_fa import attention_with_lse as _attention_with_lse

attention = _attention
attention_with_lse = _attention_with_lse
IS_AVAILABLE = True
with import_plugin("transformers"):
from .hf_triton_attention import register_triton_attention as _register_triton_attention

register_triton_attention = _register_triton_attention

__all__ = [
"IS_AVAILABLE",
"attention",
"attention_with_lse",
"register_triton_attention",
]
>>>>>>> 5ea4c609 (Added MLA support)
Contributor


⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Resolve merge conflict in module initialization.

The module has unresolved merge-conflict markers at line 16. The HEAD version has a different docstring, while the incoming version introduces conditional CUDA-gated plugin loading. Resolve the conflict to finalize the module structure.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/kernels/__init__.py` around lines 16 - 55, Remove the leftover
merge conflict markers and keep the completed initialization logic from the
incoming change: use the descriptive module docstring, import torch and
import_plugin, define IS_AVAILABLE=False and
attention/attention_with_lse/register_triton_attention=None, then if
torch.cuda.is_available() use import_plugin("triton") to import .triton_fa
attention and attention_with_lse, set IS_AVAILABLE=True, and inside an
import_plugin("transformers") block import and assign register_triton_attention
from .hf_triton_attention; finally ensure __all__ contains "IS_AVAILABLE",
"attention", "attention_with_lse", and "register_triton_attention".

Comment on lines +1323 to +1325
<<<<<<< HEAD:modelopt/torch/kernels/common/attention/triton_fa.py
__all__ = ["LOG2E", "_apply_mask", "attention"]
=======
Contributor


⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Resolve merge conflict in __all__ exports.

Determine the final export list. The incoming branch adds attention_with_lse (new API returning LSE). Verify whether LOG2E and _apply_mask from HEAD are needed for external consumers.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/kernels/common/attention/triton_fa.py` around lines 1323 -
1325, There is a merge conflict in the module exports: reconcile the __all__ to
export the correct public API by including the functions consumers need—ensure
LOG2E and _apply_mask are exported only if they are intended public symbols, and
add the new API attention_with_lse if the incoming branch introduced it; update
the __all__ to something like ["LOG2E", "_apply_mask", "attention",
"attention_with_lse"] (omit LOG2E/_apply_mask if they are internal), and keep
the symbol names exactly as defined in this file (LOG2E, _apply_mask, attention,
attention_with_lse) so imports work correctly.

Comment on lines +728 to +732
<<<<<<< HEAD
"VSA_DEFAULT",
=======
"SPARSE_SOFTMAX_SKIP_DEFAULT",
>>>>>>> 48eb4ea2 (First commit)
Contributor


⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Resolve merge conflict in __all__ export list.

The __all__ list has unresolved merge-conflict markers. HEAD exports VSA_DEFAULT, while incoming exports SPARSE_SOFTMAX_SKIP_DEFAULT. Determine whether both should be exported or choose one.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 728 - 732,
The __all__ export list contains unresolved merge markers and is currently
choosing between VSA_DEFAULT and SPARSE_SOFTMAX_SKIP_DEFAULT; remove the
conflict markers and export the correct names — include both VSA_DEFAULT and
SPARSE_SOFTMAX_SKIP_DEFAULT in the __all__ list (no merge markers, no
duplicates) so both symbols are publicly exported from the module.

Comment thread pyproject.toml
Comment on lines +52 to +53
"uvloop>=0.22.1",
"vllm==0.20.1",
Contributor


🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if vllm/uvloop imports are optional or required at modelopt import time
rg -n -C3 'import vllm|from vllm|import uvloop|from uvloop' \
  --type=py \
  --glob '!examples/**' \
  --glob '!tests/**' \
  modelopt/

Repository: NVIDIA/Model-Optimizer

Length of output: 3840


🏁 Script executed:

# Check current dependency structure in pyproject.toml
cat pyproject.toml | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 4022


🏁 Script executed:

# Check if vllm/uvloop are imported unconditionally at package initialization
find modelopt -name "__init__.py" -type f | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 786


🏁 Script executed:

# Search for uvloop usage
rg -n 'uvloop' --type=py modelopt/

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check how plugins are loaded - look for plugin registration/import logic
rg -n 'plugin|import.*vllm' modelopt/ -A2 -B2 | grep -E '(def|class|import|plugin)' | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 2799


🏁 Script executed:

# Check if vllm plugins are imported unconditionally in __init__.py files
cat modelopt/torch/quantization/plugins/__init__.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2617


🏁 Script executed:

# Check sparsity plugins init
find modelopt -path "*sparsity*plugins*__init__.py" -type f -exec cat {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 2287


🏁 Script executed:

# More thorough search for uvloop - check all Python files
rg -n 'uvloop' . --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 362


🏁 Script executed:

# Check torch module __init__ to see what is imported at package init
head -150 modelopt/torch/__init__.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2375


🏁 Script executed:

# Check import_plugin implementation to understand gating mechanism
rg -n 'def import_plugin' modelopt/ -A15

Repository: NVIDIA/Model-Optimizer

Length of output: 1307


🏁 Script executed:

# Check if sparsity/attention_sparsity/plugins/vllm.py is imported anywhere
rg -n 'attention_sparsity.*vllm|from.*attention_sparsity.*plugins' modelopt/ -B2 -A2

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check if sparsity has its own vllm plugin import
find modelopt/torch/sparsity -name "__init__.py" -exec grep -l "vllm\|import_plugin" {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 189


🏁 Script executed:

# Check if sparsity/attention_sparsity/plugins/vllm.py is actually imported/used
cat modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1343


🏁 Script executed:

# Double-check: Is there any import of sparsity.attention_sparsity.plugins.vllm anywhere?
rg -n 'sparsity.attention_sparsity.plugins.vllm|from.*sparsity.*attention.*vllm' .

Repository: NVIDIA/Model-Optimizer

Length of output: 183


🏁 Script executed:

# Check the sparsity module __init__.py to see what's imported
cat modelopt/torch/sparsity/__init__.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1112


Move uvloop to example-specific requirements and consider making vllm optional.

Per coding guidelines, uvloop is used only in examples/vllm_serve/ scripts and should be listed in examples/vllm_serve/requirements.txt, not in core pyproject.toml dependencies. It is not imported anywhere in modelopt core code.

For vllm: While it is used in quantization plugins (gated via import_plugin("vllm")), it's only accessed by code explicitly importing the vllm plugin. Consider moving to [project.optional-dependencies] with a note that it's required for vLLM sparse-attention integration. Additionally, the exact version pin (==0.20.1) is fragile for a runtime dependency; consider using a version range (~=0.20.1 or >=0.20.1,<0.21).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@pyproject.toml` around lines 52 - 53, Remove "uvloop>=0.22.1" from the core
pyproject.toml dependencies and add it to the
examples/vllm_serve/requirements.txt since uvloop is only used by the example
scripts (refer to examples/vllm_serve/). For "vllm==0.20.1", move it out of core
dependencies into [project.optional-dependencies] (or an extras group) and
document that it is required only for the vLLM quantization plugin invoked via
import_plugin("vllm"); also relax the pin to a range such as ~=0.20.1 or
>=0.20.1,<0.21 to avoid a fragile exact-match dependency.

Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
@rohansjoshi rohansjoshi force-pushed the rohjoshi/sparse-kernel branch from a90c9d6 to ea47408 Compare May 15, 2026 19:15
@github-actions
Contributor

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1499/

Built to branch gh-pages at 2026-05-15 19:19 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/config.py (1)

705-715: ⚡ Quick win

Align skip-softmax field docs with the new combined preset.

SPARSE_SOFTMAX_SKIP_DEFAULT sets skip_softmax_threshold with method="triton_sparse_softmax", but the skip_softmax_threshold field description still says it is only used by triton_skip_softmax. Please update that description to include combined sparse+skip Triton usage.

♻️ Suggested doc update
-            "Tiles contributing less than this fraction are skipped entirely. "
-            "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
+            "Tiles contributing less than this fraction are skipped entirely. "
+            "Used by Triton skip-softmax modes (e.g., triton_skip_softmax and "
+            "triton_sparse_softmax when skip-softmax is enabled). Typical values: 1e-3 to 1e-1. "
             "Set to 0 to disable."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 705 - 715,
Update the documentation string for the skip_softmax_threshold field to reflect
that it is used not only by "triton_skip_softmax" but also by the combined
Triton preset "triton_sparse_softmax" (as used in SPARSE_SOFTMAX_SKIP_DEFAULT);
locate the skip_softmax_threshold description near the configuration definitions
and modify the text to state it applies to both triton_skip_softmax and
triton_sparse_softmax (combined sparse+skip Triton usage), ensuring examples or
notes mention SPARSE_SOFTMAX_SKIP_DEFAULT where appropriate.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/kernels/common/attention/triton_fa.py`:
- Line 838: The paged KV mode (is_paged = k_cache is not None) is incompatible
with autograd because the forward reads from K_cache/V_cache while the backward
uses saved k/v in ctx, leading to incorrect gradients; update attention() to
either (a) add a runtime guard that raises an error when is_paged is true and
(k.requires_grad or v.requires_grad) to prevent silent gradient corruption, or
(b) clearly document the limitation in the attention() docstring (describe that
paged KV mode is inference-only and gradients are unsupported). Reference
is_paged, k_cache, v_cache, ctx, and function attention() when making the
change.

---

Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 705-715: Update the documentation string for the
skip_softmax_threshold field to reflect that it is used not only by
"triton_skip_softmax" but also by the combined Triton preset
"triton_sparse_softmax" (as used in SPARSE_SOFTMAX_SKIP_DEFAULT); locate the
skip_softmax_threshold description near the configuration definitions and modify
the text to state it applies to both triton_skip_softmax and
triton_sparse_softmax (combined sparse+skip Triton usage), ensuring examples or
notes mention SPARSE_SOFTMAX_SKIP_DEFAULT where appropriate.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 8012e290-b808-40e5-8ad6-ec21855c1bde

📥 Commits

Reviewing files that changed from the base of the PR and between a90c9d6 and ea47408.

📒 Files selected for processing (8)
  • examples/vllm_serve/sparse_attn_worker.py
  • modelopt/torch/kernels/common/attention/triton_fa.py
  • modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
  • modelopt/torch/quantization/plugins/vllm.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
  • pyproject.toml
  • tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • pyproject.toml
  • modelopt/torch/quantization/plugins/vllm.py
  • examples/vllm_serve/sparse_attn_worker.py
  • tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py

kv_group_num = num_q_heads // num_kv_heads
batch = b_seq_len.shape[0]

is_paged = k_cache is not None
Contributor


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Paged KV mode is incompatible with autograd; consider adding a guard or documenting the limitation.

When is_paged=True, the forward kernel reads from K_cache/V_cache, but the backward kernels read from the contiguous k/v tensors saved in ctx. If these differ, gradients will be silently incorrect.

For the vLLM inference use case this PR targets, this is fine (no gradients needed). However, consider either:

  1. Adding a runtime guard when k.requires_grad or v.requires_grad with paged mode, or
  2. Documenting this limitation in the attention() docstring's paged KV parameters.
📝 Suggested documentation update
         k_cache: Paged K cache [num_blocks, page_size, num_kv_heads, head_dim].
             When provided, K/V are read from paged cache via block_table
-            instead of from contiguous k/v tensors.
+            instead of from contiguous k/v tensors. **Note:** Autograd is not
+            supported in paged mode; backward gradients will be incorrect.
         v_cache: Paged V cache [num_blocks, page_size, num_kv_heads, head_dim].
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/kernels/common/attention/triton_fa.py` at line 838, The paged
KV mode (is_paged = k_cache is not None) is incompatible with autograd because
the forward reads from K_cache/V_cache while the backward uses saved k/v in ctx,
leading to incorrect gradients; update attention() to either (a) add a runtime
guard that raises an error when is_paged is true and (k.requires_grad or
v.requires_grad) to prevent silent gradient corruption, or (b) clearly document
the limitation in the attention() docstring (describe that paged KV mode is
inference-only and gradients are unsupported). Reference is_paged, k_cache,
v_cache, ctx, and function attention() when making the change.
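
If option 1 (a runtime guard) is taken, it could look roughly like the following near the start of attention() once is_paged is known; this is a sketch of the suggestion, not the committed fix.

# Hypothetical guard: paged KV-cache mode is inference-only.
if is_paged and (k.requires_grad or v.requires_grad):
    raise RuntimeError(
        "Paged KV mode does not support autograd: the backward pass reads the "
        "contiguous k/v tensors saved in ctx, not k_cache/v_cache."
    )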


codecov Bot commented May 15, 2026

Codecov Report

❌ Patch coverage is 9.09091% with 160 lines in your changes missing coverage. Please review.
✅ Project coverage is 62.55%. Comparing base (e27f76f) to head (ea47408).
⚠️ Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
...delopt/torch/kernels/common/attention/triton_fa.py 6.59% 85 Missing ⚠️
.../torch/sparsity/attention_sparsity/plugins/vllm.py 0.00% 74 Missing ⚠️
...kernels/sparsity/attention/skip_softmax_helpers.py 0.00% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (e27f76f) and HEAD (ea47408).

HEAD has 11 uploads less than BASE
Flag BASE (e27f76f) HEAD (ea47408)
gpu 3 1
examples 12 3
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1499       +/-   ##
===========================================
- Coverage   77.44%   62.55%   -14.89%     
===========================================
  Files         473      474        +1     
  Lines       51418    52331      +913     
===========================================
- Hits        39819    32736     -7083     
- Misses      11599    19595     +7996     
Flag Coverage Δ
examples 32.82% <5.68%> (-8.93%) ⬇️
gpu 26.71% <5.68%> (-33.60%) ⬇️
regression 15.11% <5.68%> (+0.20%) ⬆️
unit 52.58% <9.09%> (+0.04%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

