
[4/n] Add vLLM integration for modelopt sparse attention #1127

Open
kaix-nv wants to merge 9 commits into main from kaix/sparse_attn_vllm_integration

Conversation

Contributor

@kaix-nv kaix-nv commented Mar 27, 2026

What does this PR do?

Type of change: New feature

Adds vLLM integration for ModelOpt sparse attention with paged KV cache support.

Extends the Triton flash attention kernel (triton_fa.py) with paged KV cache support. KV cache can be read directly from vLLM's non-contiguous paged cache via block_table lookup, avoiding expensive gather-to-contiguous copies.
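As an illustration of that lookup (a PyTorch reference sketch, not the PR's in-kernel Triton code; the helper name gather_kv_token is hypothetical):

import torch

def gather_kv_token(k_cache, block_table, seq_idx, token_pos, page_size):
    """Read one token's K vector straight from the paged cache.

    k_cache:     [num_blocks, page_size, num_kv_heads, head_dim]
    block_table: [num_seqs, max_blocks_per_seq], logical -> physical block ids
    """
    physical_block = block_table[seq_idx, token_pos // page_size]
    slot = token_pos % page_size
    return k_cache[physical_block, slot]  # [num_kv_heads, head_dim], no gather copy

k_cache = torch.randn(4, 16, 8, 64)         # 4 physical blocks of 16 tokens each
block_table = torch.tensor([[2, 0, 3, 1]])  # one sequence, shuffled block order
k_vec = gather_kv_token(k_cache, block_table, seq_idx=0, token_pos=21, page_size=16)
# token 21 sits in logical block 1 -> physical block 0, slot 5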

SparseVLLMAttention wraps vLLM's Attention layer. It lets vLLM write KV to its paged cache, then calls the ModelOpt Triton kernel with k_cache, v_cache, block_table for both prefill and decode.
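A minimal sketch of the cache-unpacking step, assuming the stacked [2, num_blocks, page_size, num_kv_heads, head_dim] layout discussed in the review threads:

import torch

# Combined paged cache as vLLM hands it to the layer (layout assumed per review).
kv_cache = torch.randn(2, 4, 16, 8, 64)

k_cache, v_cache = kv_cache.unbind(0)  # each [num_blocks, page_size, heads, dim]
page_size = k_cache.shape[1]           # tokens per physical block, derived rather than hardcoded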

SparseAttnWorker patches vLLM attention modules at model load time; SparseQuantWorker combines quantization and sparse attention. Worker selection is automatic, based on the SPARSE_ATTN_CFG and QUANT_CFG environment variables.
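A simplified sketch of that selection (names from this PR's launcher and workers; the quant-only branch follows the routing reviewers suggested rather than the exact shipped code):

import os

has_sparse = bool(os.environ.get("SPARSE_ATTN_CFG"))
has_quant = bool(os.environ.get("QUANT_CFG"))

if has_sparse and has_quant:
    worker_cls = "sparse_attn_worker.SparseQuantWorker"  # quantization + sparsity
elif has_sparse:
    worker_cls = "sparse_attn_worker.SparseAttnWorker"   # sparsity only
elif has_quant:
    worker_cls = "fakequant_worker.FakeQuantWorker"      # quantization only
else:
    worker_cls = None                                    # fall back to standard vLLM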

vllm_serve_sparse_attn.py launches a vLLM OpenAI-compatible server with sparse attention enabled.

Usage
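A hypothetical launch, assuming the launcher script and environment variables named above (the model path and preset name are placeholders):

import os
import subprocess

env = dict(os.environ, SPARSE_ATTN_CFG="SPARSE_SOFTMAX_DEFAULT")  # preset resolved by the worker
subprocess.run(
    [
        "python", "examples/vllm_serve/vllm_serve_sparse_attn.py",
        "--model", "/path/to/sparse-calibrated-checkpoint",  # plus any standard vLLM serve flags
    ],
    env=env,
    check=True,
)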

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for serving models with sparse attention in vLLM, enabling efficient inference through optimized attention computation.
    • Sparse attention configuration is automatically loaded from model checkpoints.
    • Enhanced kernel support for paged KV-cache, improving memory efficiency during serving.
  • Documentation

    • Added guide for deploying sparse attention models with vLLM, including workflow and usage examples.
  • Chores

    • Updated import configuration for better code organization.

Review Change Stack


copy-pr-bot Bot commented Mar 27, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai Bot commented Mar 27, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

This PR enables serving ModelOpt-calibrated sparse attention models in vLLM by adding paged KV-cache support to the Triton attention kernel, introducing a vLLM backend plugin that routes attention through the sparse kernel, providing configuration loading/matching utilities, integrating a worker that patches attention at load time, and adding comprehensive tests and documentation.

Changes

Sparse attention serving integration

  • Paged KV-cache support in Triton attention kernel
    Files: modelopt/torch/kernels/common/attention/triton_fa.py, modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
    New _load_paged_k_tile and _load_paged_v_tile Triton helpers load K/V from paged caches via block tables. The forward kernel accepts K_cache, V_cache, Block_table, and a constexpr IS_PAGED flag; KV loads branch to paged or contiguous paths in the main loop. The autograd wrapper detects paged mode, rejects gradients when paged, and converts the skip-softmax threshold via log2 without scaling. The public attention() API accepts k_cache, v_cache, block_table, and page_size parameters.

  • Sparse attention config loading and export
    Files: modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py, modelopt/torch/sparsity/attention_sparsity/conversion.py
    The new sparse_attn_config.py module provides an ALGO_TO_PRESET mapping, target-sparsity normalization, match_sparse_config for glob-based layer-name matching, and load_from_checkpoint_metadata to extract and validate configs from HF checkpoints. export_sparse_attention_config() now also captures and exports target_sparse_ratio from sparse method instances.

  • vLLM sparse attention backend plugin
    Files: modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
    ModelOptSparseAttentionImpl.forward() translates vLLM attention metadata (query layout, sequence lengths, block tables) into Triton sparse kernel arguments, enforces prefill-only constraints, unpacks paged KV, and converts calibration parameters to skip-softmax thresholds. The backend registers as "MODELOPT_SPARSE". _clone_sparse_impl() creates sparse impl instances from vLLM FlashAttention while preserving runtime state and rejecting non-None sinks.

  • vLLM worker and server launcher
    Files: examples/vllm_serve/sparse_attn_worker.py, examples/vllm_serve/vllm_serve_sparse_attn.py
    SparseAttnWorker.load_model() calls the base loader, then patches vLLM attention implementations via _replace_attention_impl, which loads the sparse config from checkpoint metadata, matches per-layer configs, and rewires each Attention.impl with sparse kernel kwargs. The server launcher script handles vLLM version compatibility, configures paths for local worker imports, and launches the OpenAI server with the sparse worker as the default.

  • Tests and documentation
    Files: tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py, tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py, tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py, tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py, tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py, examples/vllm_serve/README.md, pyproject.toml
    GPU tests validate paged KV threshold consistency and vLLM plugin correctness (prefill matching, unsupported-config rejection, profiling mode, page-size inference). Unit tests cover sparse config pattern matching, checkpoint metadata loading, impl cloning, and threshold resolution. The README documents the sparse attention workflow, the sparse_algo mapping, and limitations. The isort config adds vllm to third-party imports.
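To make the glob-based layer matching above concrete, here is an illustrative matcher in the spirit of match_sparse_config; the first-match precedence and "default" fallback shown are assumptions, not necessarily the PR's exact semantics:

from fnmatch import fnmatch

def match_sparse_config(layer_name, sparse_cfg):
    """Return the first pattern entry matching layer_name, else the default entry."""
    for pattern, layer_cfg in sparse_cfg.items():
        if pattern != "default" and fnmatch(layer_name, pattern):
            return layer_cfg
    return sparse_cfg.get("default", {"enable": False})

cfg = {"*self_attn*": {"enable": True, "method": "triton_sparse_softmax"},
       "default": {"enable": False}}
print(match_sparse_config("model.layers.3.self_attn", cfg))  # matched: enabled
print(match_sparse_config("model.embed_tokens", cfg))        # fallback: disabled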

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: 6 of 6 passed

  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title clearly describes the main change, adding vLLM integration for ModelOpt sparse attention, which is the primary focus of all modified and new files across the changeset.
  • Docstring Coverage ✅ Passed: docstring coverage is 92.98%, which is sufficient; the required threshold is 80.00%.
  • Linked Issues Check ✅ Passed: check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check ✅ Passed: check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns ✅ Passed: no security anti-patterns found: safe torch.load/numpy.load patterns, no hardcoded trust_remote_code, no eval/exec, no # nosec comments, no unsafe dependencies added.





Comment @coderabbitai help to get the list of available commands and usage tips.

@kaix-nv kaix-nv changed the title Add vLLM integration for modelopt sparse attention [4/n] Add vLLM integration for modelopt sparse attention Mar 27, 2026

github-actions Bot commented Mar 27, 2026

PR Preview Action v1.8.1


🚀 View preview at https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1127/

Built to branch gh-pages at 2026-05-16 22:24 UTC.
Preview will be ready when the GitHub Pages deployment is complete.


codecov Bot commented Mar 27, 2026

Codecov Report

❌ Patch coverage is 43.09392% with 103 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.75%. Comparing base (81c4fb2) to head (20034a2).

Files with missing lines (patch %, missing lines):
  • .../torch/sparsity/attention_sparsity/plugins/vllm.py: 0.00%, 73 missing ⚠️
  • ...delopt/torch/kernels/common/attention/triton_fa.py: 46.66%, 24 missing ⚠️
  • ...y/attention_sparsity/plugins/sparse_attn_config.py: 89.36%, 5 missing ⚠️
  • .../attention_sparsity/methods/triton_skip_softmax.py: 75.00%, 1 missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1127      +/-   ##
==========================================
- Coverage   76.91%   76.75%   -0.17%     
==========================================
  Files         473      475       +2     
  Lines       51439    51583     +144     
==========================================
+ Hits        39566    39593      +27     
- Misses      11873    11990     +117     
Flag Coverage Δ
examples 41.50% <16.02%> (+0.80%) ⬆️
gpu 59.58% <17.67%> (-0.76%) ⬇️
regression 14.94% <0.00%> (+0.03%) ⬆️
unit 52.55% <30.38%> (-0.05%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attn_vllm_integration branch 6 times, most recently from 26c6b3b to e4c4680 on March 28, 2026 at 23:05
@kaix-nv kaix-nv marked this pull request as ready for review March 30, 2026 21:05
@kaix-nv kaix-nv requested review from a team as code owners March 30, 2026 21:05

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (1)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py (1)

290-336: Assert decode correctness, not just finiteness.

This case still passes if paged decode reads the wrong blocks or masks the wrong keys, because it only checks shape/NaNs. Please compare out against a contiguous decode reference here as well.

As per coding guidelines, "Write tests using pytest for all new features and examples; organize tests into tests/unit (fast CPU-based), tests/gpu (fast GPU-based), tests/gpu_megatron (Megatron-Core), tests/gpu_trtllm (TensorRT-LLM), and tests/examples (integration tests)" and "All test coverage checks in PRs must pass for new features and examples."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py` around
lines 290 - 336, The test currently only checks shape/NaNs for paged decode; add
a correctness assertion by computing a contiguous-reference decode (call the
same attention function with the original k_flat/v_flat and without
k_cache/v_cache/block_table/page_size, i.e., the non-paged code path) and
compare outputs from test_paged_decode to that reference using
torch.testing.assert_allclose (or torch.allclose with a small rtol/atol) to
ensure values match; keep the existing shape and NaN checks and reuse q_flat,
k_flat, v_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k, and
scale so the only difference is paged vs contiguous.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 137-140: The worker assumes a different kv_cache axis order than
the vllm plugin; instead of slicing kv_cache[:, 0] and kv_cache[:, 1], normalize
to the same layout used elsewhere by splitting with kv_cache.unbind(0) so
k_cache and v_cache come from that unbind; update page_size to derive from the
resulting k_cache shape (e.g., k_cache.shape[1]) and ensure any downstream uses
of k_cache/v_cache match this normalized layout (references: kv_cache, k_cache,
v_cache and the existing kv_cache.unbind(0) usage in the vllm plugin).
- Around line 297-300: The import inside compile_or_warm_up_model currently uses
a relative import ("from .fakequant_worker import _fakequant_run_prolog_worker,
quant_config") which fails when the module is loaded as a top-level module;
change it to a top-level/absolute import or dynamic import so the code works
whether loaded as a package or directly (e.g., use "from fakequant_worker import
_fakequant_run_prolog_worker, quant_config" or use
importlib.import_module("fakequant_worker") and grab the attributes). Update the
import used by compile_or_warm_up_model and any callers expecting
_fakequant_run_prolog_worker and quant_config accordingly so the class
SparseQuantWorker can be imported as a top-level module without ImportError.
- Around line 273-284: The replacement sets sliding_window=None which disables
local/sliding-window attention; update the instantiation of
ModelOptSparseAttentionImpl (the assignment to module.impl) to pass through the
original value (old_impl.sliding_window) instead of None, or add a guard that
rejects/raises when old_impl.sliding_window is non-None and unsupported; ensure
you reference old_impl.sliding_window and ModelOptSparseAttentionImpl in the
change so sliding-window behavior is preserved or explicitly handled.
- Around line 176-183: The decode call to triton_attention in
sparse_attn_worker.py is using is_causal=True which causes incorrect masking for
paged KV; change the triton_attention invocation (the call that sets
q=query[offset: offset+nd], k=query[:0], v=query[:0],
b_start_loc=dm.query_start_loc, b_seq_len=..., max_input_len=1) to pass
is_causal=False instead, matching the decode path in
modelopt/torch/kernels/hf_triton_attention.py so later cached KV tiles are not
truncated.

In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 945-958: The code currently allows paged-mode (is_paged True) but
does not disable/guard autograd, causing backward to recompute using contiguous
K/V and dummy b_start_loc_k and produce wrong gradients; update the forward
entrypoint that sets is_paged (look for the block using is_paged, b_start_loc_k
and b_start_loc) to explicitly disallow autograd in paged mode by either raising
a clear exception when torch.is_grad_enabled() (or when requires_grad on inputs)
and is_paged is True, or by wrapping the paged-mode path in torch.no_grad() and
documenting that backward is unsupported; ensure the guard references is_paged
and b_start_loc_k so callers cannot silently run backward with dummy
b_start_loc_k.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py`:
- Around line 81-99: The current loop builds a single sparse_kw from the first
enabled entry in _sparse_config["sparse_cfg"] and breaks, which ignores
layer-specific patterns; instead, for each module use name-based matching to
select the correct layer_cfg (reuse the matching logic from
examples/vllm_serve/sparse_attn_worker.py::_match_sparse_config or call that
helper) and then build a per-module sparse_kw from that matched layer_cfg
(respecting fields like sparsity_n, sparsity_m, num_sink_tokens,
dense_window_size, skip_softmax_threshold); do not break out of the loop—apply
the matched config only to the current module or stash sparse_kw on the module
instance before swapping implementations so multiple patterns in the calibration
file are handled correctly.

---

Nitpick comments:
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py`:
- Around line 290-336: The test currently only checks shape/NaNs for paged
decode; add a correctness assertion by computing a contiguous-reference decode
(call the same attention function with the original k_flat/v_flat and without
k_cache/v_cache/block_table/page_size, i.e., the non-paged code path) and
compare outputs from test_paged_decode to that reference using
torch.testing.assert_allclose (or torch.allclose with a small rtol/atol) to
ensure values match; keep the existing shape and NaN checks and reuse q_flat,
k_flat, v_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k, and
scale so the only difference is paged vs contiguous.


📥 Commits

Reviewing files that changed from the base of the PR and between 24ceba6 and e4c4680.

📒 Files selected for processing (7)
  • examples/vllm_serve/sparse_attn_worker.py
  • examples/vllm_serve/vllm_serve_sparse_attn.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
  • pyproject.toml
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
examples/vllm_serve/sparse_attn_worker.py (1)

145-152: ⚠️ Potential issue | 🟠 Major

Preserve existing sliding_window when replacing attention impl

Line 151 hardcodes sliding_window=None, which can silently change local/sliding-window attention behavior. Pass through old_impl.sliding_window (or explicitly reject unsupported non-None values).

Proposed fix
             module.impl = ModelOptSparseAttentionImpl(
                 num_heads=old_impl.num_heads,
                 head_size=old_impl.head_size,
                 scale=old_impl.scale,
                 num_kv_heads=old_impl.num_kv_heads,
                 alibi_slopes=old_impl.alibi_slopes,
-                sliding_window=None,
+                sliding_window=old_impl.sliding_window,
                 kv_cache_dtype=old_impl.kv_cache_dtype,
                 logits_soft_cap=old_impl.logits_soft_cap,
                 attn_type=old_impl.attn_type,
                 kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 145 - 152, The
construction of ModelOptSparseAttentionImpl currently hardcodes
sliding_window=None which can change attention behavior; update the initializer
in the replacement code to pass through old_impl.sliding_window (i.e., use
sliding_window=old_impl.sliding_window) or, if non-None values are unsupported,
explicitly check old_impl.sliding_window and raise an informative error before
creating ModelOptSparseAttentionImpl; reference ModelOptSparseAttentionImpl,
old_impl, and old_impl.sliding_window to locate and fix the code.
🧹 Nitpick comments (1)
examples/vllm_serve/sparse_attn_worker.py (1)

111-119: Remove or wire _match_sparse_config to avoid dead-path drift

_match_sparse_config is currently unused, which makes behavior harder to reason about and can drift from real matching logic. Either use it in patching/selection flow or remove it until needed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 111 - 119, The helper
function _match_sparse_config is unused and creates dead-path drift; either
remove this function or wire it into the sparse patch/selection flow by
replacing the current pattern-matching logic with a call to
_match_sparse_config(module_name, sparse_cfg) (or call it from wherever sparse
layer configs are looked up) so that matching behavior is centralized; update
any callers that currently duplicate pattern checks to use _match_sparse_config
and remove dead duplicates if you choose to keep it.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 81-83: The code currently returns whatever getattr(mtsa, cfg_name,
None) yields (cfg) and may hand back non-mapping objects; update the getter to
validate that cfg is a dict before returning (use isinstance(cfg, dict)) and
otherwise raise a clear error (e.g., ValueError) indicating that the symbol
named by cfg_name in modelopt.torch.sparsity.attention_sparsity must be a dict;
reference the getattr(mtsa, cfg_name, None) call and the cfg variable to locate
the change.
- Around line 92-106: The _load_sparse_config function currently trusts
arbitrary JSON from SPARSE_CALIB_CONFIG_PATH; update it to validate the loaded
object and each layer_cfg: assert the top-level JSON is a dict, allowed
top-level keys are strings and either "calibration" or pattern names, and each
layer_cfg is a dict before applying defaults; enforce allowed keys (e.g.,
"method", "backend", "enable", numeric sparsity params) and bounds for numeric
fields (e.g., sparsity percentages 0–100, integer layer indices >=0, and
reasonable max limits) and reject or clamp out-of-range values, raising a clear
exception on invalid schema; keep the existing defaults
(method="triton_sparse_softmax", backend="triton", enable=True) for valid
entries and ensure sparse_cfg["default"] = {"enable": False} remains set.

---

Duplicate comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 145-152: The construction of ModelOptSparseAttentionImpl currently
hardcodes sliding_window=None which can change attention behavior; update the
initializer in the replacement code to pass through old_impl.sliding_window
(i.e., use sliding_window=old_impl.sliding_window) or, if non-None values are
unsupported, explicitly check old_impl.sliding_window and raise an informative
error before creating ModelOptSparseAttentionImpl; reference
ModelOptSparseAttentionImpl, old_impl, and old_impl.sliding_window to locate and
fix the code.

---

Nitpick comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 111-119: The helper function _match_sparse_config is unused and
creates dead-path drift; either remove this function or wire it into the sparse
patch/selection flow by replacing the current pattern-matching logic with a call
to _match_sparse_config(module_name, sparse_cfg) (or call it from wherever
sparse layer configs are looked up) so that matching behavior is centralized;
update any callers that currently duplicate pattern checks to use
_match_sparse_config and remove dead duplicates if you choose to keep it.


📥 Commits

Reviewing files that changed from the base of the PR and between e4c4680 and 4644bf5.

📒 Files selected for processing (1)
  • examples/vllm_serve/sparse_attn_worker.py

Comment on lines +81 to +83
cfg = getattr(mtsa, cfg_name, None)
if cfg is not None:
    return cfg


⚠️ Potential issue | 🟠 Major

Validate preset object type before returning it

If SPARSE_ATTN_CFG matches a non-dict symbol in modelopt.torch.sparsity.attention_sparsity, cfg is returned as-is and later consumed as a mapping, which can crash at runtime. Add an explicit isinstance(cfg, dict) guard and fail fast with a clear error.

Proposed fix
     cfg = getattr(mtsa, cfg_name, None)
     if cfg is not None:
-        return cfg
+        if not isinstance(cfg, dict):
+            raise ValueError(
+                f"Invalid sparse config preset '{cfg_name}': expected dict, got {type(cfg).__name__}."
+            )
+        return cfg
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 81 - 83, The code
currently returns whatever getattr(mtsa, cfg_name, None) yields (cfg) and may
hand back non-mapping objects; update the getter to validate that cfg is a dict
before returning (use isinstance(cfg, dict)) and otherwise raise a clear error
(e.g., ValueError) indicating that the symbol named by cfg_name in
modelopt.torch.sparsity.attention_sparsity must be a dict; reference the
getattr(mtsa, cfg_name, None) call and the cfg variable to locate the change.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attn_vllm_integration branch from 4644bf5 to 54079b8 on March 31, 2026 at 00:46
@kaix-nv kaix-nv requested review from Edwardf0t1 and jingyu-ml March 31, 2026 01:49
Collaborator

@cjluo-nv cjluo-nv left a comment


Summary: Adds paged KV cache support to the ModelOpt Triton flash attention kernel, a vLLM sparse attention plugin, worker classes for vLLM integration, and GPU tests for the paged KV path.

Issues Found:

  1. [Correctness] Backward pass silently broken for paged KV. In triton_fa.py, forward() saves k, v to ctx for backward, but in paged mode these are dummy/empty tensors (e.g., k_dummy = torch.empty(0, ...) from the vLLM plugin). If .backward() is ever called on paged-mode output, gradients for dK/dV will be computed against the dummy tensors, producing silently incorrect results. The backward should either raise NotImplementedError("Backward not supported for paged KV cache") when is_paged=True, or the limitation should be clearly documented in the attention() docstring. The current code just adds None return placeholders for the 4 new args without any guard.

  2. [Correctness] Unused import_plugin import in plugins/__init__.py. The diff adds from modelopt.torch.utils import import_plugin but it is never used in the file. The existing __init__.py doesn't use this import, and vllm.py imports directly from vLLM. This is dead code that should be removed; or, if the intent was to use import_plugin for a conditional vLLM import (as other plugin modules do), that wiring is missing.

  3. [Correctness] _build_sparse_config fallback logic is confusing (sparse_attn_worker.py:78-86). getattr(mtsa, cfg_name, None) is tried first, but then SPARSE_SOFTMAX_DEFAULT falls through to the hardcoded _DEFAULT_SPARSE_CFG dict. If mtsa actually defines SPARSE_SOFTMAX_DEFAULT in the future, the getattr path would return it and the hardcoded default would never be used, leading to silent behavior divergence. Consider making the precedence explicit or removing the duplication.

  4. [Readability] Duplicated paged V-tile loading in the triton_fa.py _attn_fwd kernel. _load_paged_v_tile is called with identical arguments in two separate branches (the skip-softmax path around line 476 and the standard path around line 518). While Triton JIT constraints may require this, the duplicated 20-line call blocks are a readability concern. A comment explaining why the duplication is necessary would help.

  5. [Tests] No backward/gradient test for paged mode. test_triton_fa_paged.py only tests forward correctness. Given that paged mode changes the autograd Function's forward signature and backward is not updated to support paged KV, there should be at minimum a test asserting that backward raises an error (if a guard is added per issue #1), or this should be explicitly documented as inference-only.

  6. [Tests] No integration test for the vLLM plugin. ModelOptSparseAttentionImpl.forward() and ModelOptSparseAttentionBackend have no test coverage. These are the most integration-critical new classes. Even a mock-based unit test validating the metadata translation logic (cu_seqlens_q → b_start_loc, seq_lens → b_seq_len_k) would catch regressions.

  7. [Tests] Inconsistent b_start_loc_k handling across tests. test_paged_matches_contiguous passes explicit b_start_loc_k=locs_k, while test_paged_no_nan omits it (relying on the dummy-zeros fallback). Both should use the same calling convention to avoid masking bugs in the fallback path.

  8. [Readability] if threshold: truthiness check (sparse_attn_worker.py:146). if threshold: evaluates False for both None and 0.0. While 0.0 correctly means "disabled", this is subtle; if threshold is not None would be clearer about intent (let the kernel handle the 0.0 case). See the snippet after this list.
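For item 8, a tiny generic-Python illustration (not the PR's code) of why the truthiness check conflates "unset" with "explicitly disabled":

def describe(threshold):
    # Truthiness check: cannot tell "unset" (None) from "explicitly 0.0".
    truthy = bool(threshold)
    # Identity check: 0.0 still counts as configured; the kernel decides what 0.0 means.
    configured = threshold is not None
    return truthy, configured

print(describe(None))  # (False, False)
print(describe(0.0))   # (False, True), the case that "if threshold:" silently drops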

Suggestions:

  • Consider adding requires_grad=False to k_dummy and v_dummy in the vLLM plugin to make the inference-only intent explicit and catch accidental backward calls early.
  • The _DEFAULT_SPARSE_CFG hardcoded in the worker could reference mtsa constants instead, reducing drift risk.
  • test_paged_decode uses q_flat with shape [batch, num_heads, head_dim] (3D) rather than the expected [total_q_tokens, num_heads, head_dim]. This works because batch * 1 = batch tokens, but the shape semantics are confusing — using q_flat.reshape(batch, num_heads, head_dim) explicitly or adding a comment would clarify.

Overall Assessment: The core kernel extension (paged KV tile loaders + IS_PAGED branching) is well-structured and the tests verify forward correctness across multiple configurations. The main blocking concern is the silent backward-pass breakage for paged mode — this needs at minimum a guard or documentation since the function is part of the public attention() API with autograd support. The unused import is a minor cleanup. The vLLM plugin lacks test coverage but is in examples/ territory so is less critical.

yeyu-nvidia added a commit that referenced this pull request Apr 3, 2026
Implements a ModelOpt sparse attention plugin for diffusers WAN models,
building on the triton_fa kernel infrastructure from PR #1127.

New files:
- modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
  - ModelOptWanAttnProcessor: replaces WanAttnProcessor, calls
    triton_fa.attention() directly with BSND->varlen conversion.
    Supports I2V cross-attention path and N:M/skip-softmax sparsity.
  - WanSparseAttentionModule: subclasses SparseAttentionModule, installs
    the Triton processor and syncs enabled state on each forward.
  - register_wan_sparse_attention(): plugin callback auto-registered in
    CUSTOM_MODEL_PLUGINS; fires during mtsa.sparsify().

Updated files:
- plugins/__init__.py: lazy-import diffusers plugin via import_plugin()
- config.py: add "diffusers_triton" to validate_backend whitelist
- conversion.py: skip HF attn registration for "diffusers_triton" backend
- wan2_sage_attention.py: add triton-sparse and triton-skip kernel options
  backed by mtsa.sparsify() with diffusers_triton backend

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
yeyu-nvidia added a commit that referenced this pull request Apr 8, 2026, with the same message as the Apr 3 commit above.
@kaix-nv kaix-nv force-pushed the kaix/sparse_attn_vllm_integration branch from 54079b8 to bbac896 on May 11, 2026 at 18:24

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/vllm_serve/vllm_serve_sparse_attn.py`:
- Around line 81-87: The current conditional treats a quant-only setup as the
"none set" case, causing an incorrect warning; update the branch logic around
has_quant, has_sparse and worker_cls so quant-only is handled explicitly: add an
elif has_quant branch that sets worker_cls to a suitable quant-only class (e.g.,
"sparse_attn_worker.QuantWorker") and only use the warning in the final else
when neither has_sparse nor has_quant is true, or change the warning message to
accurately reflect the detected state; refer to the boolean flags has_quant and
has_sparse and the variable worker_cls to locate and modify the logic.
- Around line 31-36: Move the hard imports of optional packages (uvloop, vllm
and the vllm.entrypoints imports run_server and make_arg_parser) out of module
scope and into the main() function (or a helper invoked by main), using the
project's import_plugin() helper or try/except ImportError to lazily load them
and provide a clear error message if missing; update any top-level references to
run_server/make_arg_parser to use the locally imported names inside main() so
the module can be imported even when those extras are not installed.
- Around line 72-76: The PYTHONPATH construction currently appends repo_root
using a hardcoded ":" which is non-portable and creates a leading separator when
PYTHONPATH is empty; update the logic that sets os.environ["PYTHONPATH"] to use
os.pathsep and only prepend the separator when an existing value is non-empty
(i.e., read current = os.environ.get("PYTHONPATH", "") and set it to current +
os.pathsep + repo_root if current else repo_root) so that repo_root, sys.path
insertion, and environment modification remain correct and platform-independent.


📥 Commits

Reviewing files that changed from the base of the PR and between 54079b8 and bbac896.

📒 Files selected for processing (2)
  • examples/vllm_serve/sparse_attn_worker.py
  • examples/vllm_serve/vllm_serve_sparse_attn.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/vllm_serve/sparse_attn_worker.py

Comment on lines +31 to +36
import uvloop
import vllm
from packaging import version
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser



⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

# First, locate and read the file in question
cd examples/vllm_serve && ls -la && wc -l vllm_serve_sparse_attn.py

Repository: NVIDIA/Model-Optimizer

Length of output: 771


🏁 Script executed:

# Read the file with line numbers to see the actual imports and structure
head -100 examples/vllm_serve/vllm_serve_sparse_attn.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 4153


🏁 Script executed:

# Search for import_plugin usage in the codebase to understand the pattern
rg "import_plugin" --type py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 21147


🏁 Script executed:

# Check pyproject.toml for optional extras definition
grep -A 20 "optional-dependencies\|extras" pyproject.toml | head -40

Repository: NVIDIA/Model-Optimizer

Length of output: 1570


Move vLLM and uvloop imports into main() function to handle missing optional dependencies gracefully.

Lines 31–35 perform hard module-level imports of uvloop, vllm, and related entrypoints. This breaks the module import in environments without these optional packages installed. Relocate these imports inside main() with appropriate error handling, or use import_plugin() as established throughout the codebase for optional integrations.

As per coding guidelines: "**/*.py: Use optional dependencies gated by install extras; avoid hard imports at module level for optional features" and "Load optional integrations lazily via import_plugin()."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/vllm_serve/vllm_serve_sparse_attn.py` around lines 31 - 36, Move the
hard imports of optional packages (uvloop, vllm and the vllm.entrypoints imports
run_server and make_arg_parser) out of module scope and into the main() function
(or a helper invoked by main), using the project's import_plugin() helper or
try/except ImportError to lazily load them and provide a clear error message if
missing; update any top-level references to run_server/make_arg_parser to use
the locally imported names inside main() so the module can be imported even when
those extras are not installed.

@kaix-nv kaix-nv requested a review from shengliangxu May 11, 2026 23:11
@kaix-nv kaix-nv enabled auto-merge (squash) May 11, 2026 23:28
Collaborator

@shengliangxu shengliangxu left a comment


LGTM

@kaix-nv kaix-nv requested a review from a team May 11, 2026 23:35
Contributor

@meenchen meenchen left a comment


Can you also update the README of the example for sparse attention?

@@ -0,0 +1,226 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Contributor


Can this worker be merged with fakequant_worker? Ideally, we would like a unified entry point for both quantization and sparsity, so we can simulate quantization and sparsity at the same time.

Contributor Author


SparseQuantWorker in sparse_attn_worker.py already supports this. Currently, we have three workers:

  • FakeQuantWorker in fakequant_worker.py (quantization only)
  • SparseAttnWorker in sparse_attn_worker.py (sparsity only)
  • SparseQuantWorker in sparse_attn_worker.py (quantization + sparsity) — this is already the unified implementation

We can consolidate these three workers into a single unified worker, such as ModelOptWorker, in a follow-up PR.

Collaborator

@cjluo-nv cjluo-nv left a comment


Bot review — DM the bot to share feedback.

Re-review. Several critical items from prior rounds are still unresolved:

Unresolved critical:

  1. Backward pass silently incorrect in paged mode (flagged by cjluo-nv and CodeRabbit). _Attention.forward() does ctx.save_for_backward(q, k, v, ...) and the backward kernels index K/V contiguously using b_start_loc_k. In paged mode, callers (plugins/vllm.py) pass k_dummy = torch.empty(0, num_kv_heads, head_size, ...) and b_start_loc_k=None (then zero-filled). If anything ever calls .backward() on paged output, backward will either OOB-load or compute garbage gradients silently. There is still no guard. At minimum add if is_paged and (q.requires_grad or k.requires_grad or v.requires_grad): raise NotImplementedError(...) in forward, and document this limitation in the attention() docstring.

  2. No tests for the new vLLM plugin (ModelOptSparseAttentionImpl.forward, ModelOptSparseAttentionBackend). The paged kernel has good tests, but the integration-critical layer (metadata translation query_start_loc → b_start_loc, seq_lens → b_seq_len_k, KV cache unbind / page_size derivation) is untested. Given the "workers disagree on kv_cache axis order" bug found last round, a small mock-based unit test here is important.

  3. README not updated as meenchen requested.

Unresolved minor / cleanup:
4. modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py adds from modelopt.torch.utils import import_plugin but never uses it. Either wire vllm.py in via import_plugin(".vllm", ...) (matching other plugin modules) or remove the import. Currently plugins/vllm.py is only reachable from the example.
5. vllm_serve_sparse_attn.py line 75: os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}" is not portable and creates a leading : when PYTHONPATH is unset. Use os.pathsep and conditional join.
6. vllm_serve_sparse_attn.py warning branch (line 85–86) prints "No SPARSE_ATTN_CFG or QUANT_CFG set" even when QUANT_CFG is set but sparse is not; the message is inaccurate for the quant-only case.
7. sparse_attn_worker.py _build_sparse_config has a dead branch: SPARSE_SOFTMAX_DEFAULT is exported from mtsa.config, so getattr(mtsa, "SPARSE_SOFTMAX_DEFAULT", None) already returns the canonical preset and the local _DEFAULT_SPARSE_CFG fallback can never be reached. Either remove the local fallback or explain why it's intentionally divergent.
8. sparse_attn_worker.py line 151: still passes sliding_window=None to ModelOptSparseAttentionImpl.__init__ and then fixes up with new_impl.sliding_window = old_impl.sliding_window afterward. That relies on FlashAttentionImpl.__init__ not using the value for anything else (e.g., internal config). Safer to pass old_impl.sliding_window directly through the constructor, or assert old_impl.sliding_window is unset.

Design review (required per complexity gate): this is PR 4/n of a subsystem landed incrementally, and the kernel extension + integration style (impl swap rather than a new vLLM attention backend registered via entry points) is a reasonable reuse of FlashAttentionImpl. Good: ModelOptSparseAttentionImpl inherits to reuse __init__ and do_kv_cache_update, only overriding forward. That keeps the scope focused. No second system duplicating existing in-repo sparse-attention infra is introduced.

Please address #1 (backward guard) and #3 (README) before merging; #2 (plugin tests) is strongly recommended given the axis-order bug caught last round.

skip_softmax_raw_threshold,
measure_sparsity,
k_cache,
v_cache,
Collaborator


Bot comment.

Paged mode is forward-only, but there is no guard: ctx.save_for_backward(q, k, v, ...) stores whatever the caller passed (dummy torch.empty(0, num_kv_heads, head_size) from plugins/vllm.py), and backward then indexes K/V contiguously through b_start_loc_k. A stray .backward() on paged output will OOB-load or produce silently wrong gradients.

Suggest adding right after is_paged = k_cache is not None:

if is_paged and (q.requires_grad or k.requires_grad or v.requires_grad):
    raise NotImplementedError(
        "Paged KV cache path is forward-only; backward is not implemented."
    )

and document the limitation in the attention() docstring.


"""Plugins for sparse attention integration with various frameworks."""

from modelopt.torch.utils import import_plugin
Collaborator


Bot comment.

import_plugin is imported here but never used (and vllm.py is not wired in via the usual import_plugin(".vllm", ...) pattern). Either register vllm.py through this mechanism (matching other plugin modules like huggingface) or drop this dead import.

sparsity_n = layer_cfg.get("sparsity_n", 0)
if sparsity_n > 0:
    sparse_kw["sparsity_n"] = sparsity_n
    sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4)
Collaborator


Bot comment.

Still passing sliding_window=None into __init__ and patching new_impl.sliding_window afterward. That only works if FlashAttentionImpl.__init__ doesn't use the value internally (e.g. for backend selection or capability checks). Safer to pass old_impl.sliding_window directly — and if the comment about "can't reverse it" is true, at least assert the old impl's value is either None or the raw tuple so future vLLM versions don't silently break this workaround.

if cfg is not None:
    return cfg
# Use built-in default if name matches
if cfg_name in ("SPARSE_SOFTMAX_DEFAULT", "default"):
Collaborator


Bot comment.

Dead fallback: SPARSE_SOFTMAX_DEFAULT is exported from modelopt.torch.sparsity.attention_sparsity.config (and re-exported from mtsa), so getattr(mtsa, 'SPARSE_SOFTMAX_DEFAULT', None) will always return the canonical preset — _DEFAULT_SPARSE_CFG is unreachable for that name. Either delete _DEFAULT_SPARSE_CFG or document why it intentionally differs from mtsa.SPARSE_SOFTMAX_DEFAULT (e.g. dense_window_size: 1 vs 64).

repo_root = str(Path(__file__).resolve().parent)
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}"
Collaborator


Bot comment.

Non-portable (: is POSIX-only) and prepends an empty path component when PYTHONPATH is unset. Use:

current = os.environ.get("PYTHONPATH")
os.environ["PYTHONPATH"] = os.pathsep.join([current, repo_root]) if current else repo_root

elif has_sparse:
    worker_cls = "sparse_attn_worker.SparseAttnWorker"
else:
    print("Warning: No SPARSE_ATTN_CFG or QUANT_CFG set. Running standard vLLM.")
Collaborator


Bot comment.

This branch is hit when QUANT_CFG/KV_QUANT_CFG is set but sparse env vars are not, and the message "No SPARSE_ATTN_CFG or QUANT_CFG set" is then false. Either route quant-only to fakequant_worker.FakeQuantWorker here, or make the message say "No SPARSE_ATTN_CFG set — running standard vLLM (quant-only is not handled by this launcher)".


@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
examples/vllm_serve/vllm_serve_sparse_attn.py (1)

35-45: ⚡ Quick win

Consider moving optional imports into main() for graceful error handling.

Lines 35–39 perform hard imports of uvloop, vllm, and vLLM entrypoints at module level. If these optional packages are not installed, importing this module will fail immediately. Moving these imports into main() with a try/except block would allow for clearer error messages and better align with best practices for optional dependencies.

♻️ Proposed refactor
-import uvloop
-import vllm
-from packaging import version
-from vllm.entrypoints.openai.api_server import run_server
-from vllm.entrypoints.openai.cli_args import make_arg_parser
-
-vllm_version = version.parse(vllm.__version__)
-if vllm_version <= version.parse("0.11.0"):
-    from vllm.utils import FlexibleArgumentParser
-else:
-    from vllm.utils.argparse_utils import FlexibleArgumentParser
-
-
 def main():
     """Launch vLLM with sparse attention worker."""
+    try:
+        import uvloop
+        import vllm
+        from packaging import version
+        from vllm.entrypoints.openai.api_server import run_server
+        from vllm.entrypoints.openai.cli_args import make_arg_parser
+    except ImportError as e:
+        print(f"Error: {e}")
+        print("Please install vLLM and uvloop: pip install vllm uvloop")
+        sys.exit(1)
+
+    vllm_version = version.parse(vllm.__version__)
+    if vllm_version <= version.parse("0.11.0"):
+        from vllm.utils import FlexibleArgumentParser
+    else:
+        from vllm.utils.argparse_utils import FlexibleArgumentParser
+
     parser = FlexibleArgumentParser(description="vLLM model server with sparse attention")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/vllm_serve/vllm_serve_sparse_attn.py` around lines 35 - 45, The
module-level imports of optional dependencies (uvloop, vllm, run_server,
make_arg_parser, and FlexibleArgumentParser) should be moved into the main
entrypoint (e.g., main()) and wrapped in a try/except so missing packages raise
a clear, user-friendly error instead of failing on import; update the code to
import uvloop, vllm, vllm.entrypoints.openai.api_server.run_server,
vllm.entrypoints.openai.cli_args.make_arg_parser, and the FlexibleArgumentParser
variant inside main(), detect vllm version as before, and on
ImportError/ModuleNotFoundError print a concise message indicating which package
is missing and how to install it, then exit gracefully.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@examples/vllm_serve/vllm_serve_sparse_attn.py`:
- Around line 35-45: The module-level imports of optional dependencies (uvloop,
vllm, run_server, make_arg_parser, and FlexibleArgumentParser) should be moved
into the main entrypoint (e.g., main()) and wrapped in a try/except so missing
packages raise a clear, user-friendly error instead of failing on import; update
the code to import uvloop, vllm, vllm.entrypoints.openai.api_server.run_server,
vllm.entrypoints.openai.cli_args.make_arg_parser, and the FlexibleArgumentParser
variant inside main(), detect vllm version as before, and on
ImportError/ModuleNotFoundError print a concise message indicating which package
is missing and how to install it, then exit gracefully.


📥 Commits

Reviewing files that changed from the base of the PR and between bbac896 and 252d8fb.

📒 Files selected for processing (13)
  • examples/vllm_serve/README.md
  • examples/vllm_serve/sparse_attn_worker.py
  • examples/vllm_serve/vllm_serve_sparse_attn.py
  • modelopt/torch/kernels/common/attention/triton_fa.py
  • modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
  • tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py
✅ Files skipped from review due to trivial changes (2)
  • modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
  • examples/vllm_serve/README.md

kaix-nv added 9 commits May 16, 2026 10:45, each signed off by Kai Xu <kaix@nvidia.com>.
@kaix-nv kaix-nv force-pushed the kaix/sparse_attn_vllm_integration branch from 252d8fb to 20034a2 on May 16, 2026 at 22:21
@kaix-nv kaix-nv requested a review from a team as a code owner May 16, 2026 22:21
@kaix-nv kaix-nv disabled auto-merge May 16, 2026 23:12