[4/n] Add vLLM integration for modelopt sparse attention #1127
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Note: Reviews paused. This branch is under active development, so CodeRabbit has automatically paused this review to avoid overwhelming you with comments from the influx of new commits.
📝 Walkthrough

This PR enables serving ModelOpt-calibrated sparse attention models in vLLM by adding paged KV-cache support to the Triton attention kernel, introducing a vLLM backend plugin that routes attention through the sparse kernel, providing configuration loading/matching utilities, integrating a worker that patches attention at load time, and adding comprehensive tests and documentation.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 6 passed
Codecov Report — ❌ patch coverage check. Additional details and impacted files:

@@ Coverage Diff @@
## main #1127 +/- ##
==========================================
- Coverage 76.91% 76.75% -0.17%
==========================================
Files 473 475 +2
Lines 51439 51583 +144
==========================================
+ Hits 39566 39593 +27
- Misses 11873 11990 +117
Force-pushed: 26c6b3b to e4c4680
Actionable comments posted: 6
🧹 Nitpick comments (1)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py (1)
Lines 290-336: Assert decode correctness, not just finiteness. This case still passes if paged decode reads the wrong blocks or masks the wrong keys, because it only checks shape/NaNs. Please compare `out` against a contiguous decode reference here as well.

As per coding guidelines, "Write tests using pytest for all new features and examples; organize tests into `tests/unit` (fast CPU-based), `tests/gpu` (fast GPU-based), `tests/gpu_megatron` (Megatron-Core), `tests/gpu_trtllm` (TensorRT-LLM), and `tests/examples` (integration tests)" and "All test coverage checks in PRs must pass for new features and examples."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py` around lines 290 - 336, The test currently only checks shape/NaNs for paged decode; add a correctness assertion by computing a contiguous-reference decode (call the same attention function with the original k_flat/v_flat and without k_cache/v_cache/block_table/page_size, i.e., the non-paged code path) and compare outputs from test_paged_decode to that reference using torch.testing.assert_allclose (or torch.allclose with a small rtol/atol) to ensure values match; keep the existing shape and NaN checks and reuse q_flat, k_flat, v_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k, and scale so the only difference is paged vs contiguous.
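A sketch of the suggested reference comparison, reusing the tensor names quoted above; `attention` stands in for the kernel entrypoint under test, and the exact keyword names of the paged arguments are assumptions:

```python
import torch

def assert_paged_matches_contiguous_decode(
    attention, q_flat, k_flat, v_flat,
    b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k,
    scale, k_cache, v_cache, block_table, page_size,
):
    """Run the paged decode path and a contiguous reference, then compare values."""
    out_paged = attention(
        q_flat, k_flat, v_flat,
        b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k,
        sm_scale=scale,
        k_cache=k_cache, v_cache=v_cache,
        block_table=block_table, page_size=page_size,
    )
    # Same inputs through the non-paged code path serve as the reference.
    out_ref = attention(
        q_flat, k_flat, v_flat,
        b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k,
        sm_scale=scale,
    )
    # Keep the existing shape/NaN checks, then assert value equality.
    assert out_paged.shape == out_ref.shape
    assert not torch.isnan(out_paged).any()
    torch.testing.assert_close(out_paged, out_ref, rtol=2e-2, atol=2e-2)
```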
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 137-140: The worker assumes a different kv_cache axis order than
the vllm plugin; instead of slicing kv_cache[:, 0] and kv_cache[:, 1], normalize
to the same layout used elsewhere by splitting with kv_cache.unbind(0) so
k_cache and v_cache come from that unbind; update page_size to derive from the
resulting k_cache shape (e.g., k_cache.shape[1]) and ensure any downstream uses
of k_cache/v_cache match this normalized layout (references: kv_cache, k_cache,
v_cache and the existing kv_cache.unbind(0) usage in the vllm plugin). A sketch of this normalization appears after this comment list.
- Around line 297-300: The import inside compile_or_warm_up_model currently uses
a relative import ("from .fakequant_worker import _fakequant_run_prolog_worker,
quant_config") which fails when the module is loaded as a top-level module;
change it to a top-level/absolute import or dynamic import so the code works
whether loaded as a package or directly (e.g., use "from fakequant_worker import
_fakequant_run_prolog_worker, quant_config" or use
importlib.import_module("fakequant_worker") and grab the attributes). Update the
import used by compile_or_warm_up_model and any callers expecting
_fakequant_run_prolog_worker and quant_config accordingly so the class
SparseQuantWorker can be imported as a top-level module without ImportError. A sketch of such an import fallback appears after this comment list.
- Around line 273-284: The replacement sets sliding_window=None which disables
local/sliding-window attention; update the instantiation of
ModelOptSparseAttentionImpl (the assignment to module.impl) to pass through the
original value (old_impl.sliding_window) instead of None, or add a guard that
rejects/raises when old_impl.sliding_window is non-None and unsupported; ensure
you reference old_impl.sliding_window and ModelOptSparseAttentionImpl in the
change so sliding-window behavior is preserved or explicitly handled.
- Around line 176-183: The decode call to triton_attention in
sparse_attn_worker.py is using is_causal=True which causes incorrect masking for
paged KV; change the triton_attention invocation (the call that sets
q=query[offset: offset+nd], k=query[:0], v=query[:0],
b_start_loc=dm.query_start_loc, b_seq_len=..., max_input_len=1) to pass
is_causal=False instead, matching the decode path in
modelopt/torch/kernels/hf_triton_attention.py so later cached KV tiles are not
truncated.
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 945-958: The code currently allows paged-mode (is_paged True) but
does not disable/guard autograd, causing backward to recompute using contiguous
K/V and dummy b_start_loc_k and produce wrong gradients; update the forward
entrypoint that sets is_paged (look for the block using is_paged, b_start_loc_k
and b_start_loc) to explicitly disallow autograd in paged mode by either raising
a clear exception when torch.is_grad_enabled() (or when requires_grad on inputs)
and is_paged is True, or by wrapping the paged-mode path in torch.no_grad() and
documenting that backward is unsupported; ensure the guard references is_paged
and b_start_loc_k so callers cannot silently run backward with dummy
b_start_loc_k.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py`:
- Around line 81-99: The current loop builds a single sparse_kw from the first
enabled entry in _sparse_config["sparse_cfg"] and breaks, which ignores
layer-specific patterns; instead, for each module use name-based matching to
select the correct layer_cfg (reuse the matching logic from
examples/vllm_serve/sparse_attn_worker.py::_match_sparse_config or call that
helper) and then build a per-module sparse_kw from that matched layer_cfg
(respecting fields like sparsity_n, sparsity_m, num_sink_tokens,
dense_window_size, skip_softmax_threshold); do not break out of the loop—apply
the matched config only to the current module or stash sparse_kw on the module
instance before swapping implementations so multiple patterns in the calibration
file are handled correctly.
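A minimal sketch of the kv_cache normalization asked for in the first comment above; the cache layout (K/V stacked on dim 0) is an assumption for illustration:

```python
import torch

def split_kv_cache(kv_cache: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Split a stacked KV cache into K and V and derive the page size."""
    # Assumed layout: [2, num_blocks, page_size, num_kv_heads, head_size]
    k_cache, v_cache = kv_cache.unbind(0)
    page_size = k_cache.shape[1]  # derive from the unbound K cache, not kv_cache
    return k_cache, v_cache, page_size
```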
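For the relative-import fix, a sketch that works whether sparse_attn_worker.py is loaded as part of a package or as a top-level module (module and symbol names are taken from the comment above):

```python
try:
    # Works when the worker is imported as part of a package.
    from .fakequant_worker import _fakequant_run_prolog_worker, quant_config
except ImportError:
    # Works when vLLM loads the worker as a top-level module.
    from fakequant_worker import _fakequant_run_prolog_worker, quant_config
```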
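And for the per-module config matching, a sketch assuming fnmatch-style patterns and the layer-config fields quoted above; the helper and the stashing attribute are illustrative stand-ins for the plugin's real logic:

```python
import fnmatch

def _match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None:
    """Illustrative name-based matcher, mirroring the helper in
    examples/vllm_serve/sparse_attn_worker.py."""
    for pattern, layer_cfg in sparse_cfg.items():
        if pattern != "default" and fnmatch.fnmatch(module_name, pattern):
            return layer_cfg
    return sparse_cfg.get("default")

def apply_layer_configs(model, sparse_cfg: dict) -> None:
    """Per-module application instead of break-on-first-enabled."""
    for name, module in model.named_modules():
        layer_cfg = _match_sparse_config(name, sparse_cfg)
        if not (layer_cfg and layer_cfg.get("enable", False)):
            continue
        sparse_kw = {
            key: layer_cfg[key]
            for key in ("sparsity_n", "sparsity_m", "num_sink_tokens",
                        "dense_window_size", "skip_softmax_threshold")
            if key in layer_cfg
        }
        module._sparse_kw = sparse_kw  # stash per-module config before swapping impls
```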
📒 Files selected for processing (7)
- examples/vllm_serve/sparse_attn_worker.py
- examples/vllm_serve/vllm_serve_sparse_attn.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
- modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
- pyproject.toml
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py
Actionable comments posted: 2
♻️ Duplicate comments (1)
examples/vllm_serve/sparse_attn_worker.py (1)
Lines 145-152: ⚠️ Potential issue | 🟠 Major: Preserve existing `sliding_window` when replacing the attention impl.

Line 151 hardcodes `sliding_window=None`, which can silently change local/sliding-window attention behavior. Pass through `old_impl.sliding_window` (or explicitly reject unsupported non-None values).

Proposed fix:

```diff
 module.impl = ModelOptSparseAttentionImpl(
     num_heads=old_impl.num_heads,
     head_size=old_impl.head_size,
     scale=old_impl.scale,
     num_kv_heads=old_impl.num_kv_heads,
     alibi_slopes=old_impl.alibi_slopes,
-    sliding_window=None,
+    sliding_window=old_impl.sliding_window,
     kv_cache_dtype=old_impl.kv_cache_dtype,
     logits_soft_cap=old_impl.logits_soft_cap,
     attn_type=old_impl.attn_type,
     kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/sparse_attn_worker.py` around lines 145 - 152, The construction of ModelOptSparseAttentionImpl currently hardcodes sliding_window=None which can change attention behavior; update the initializer in the replacement code to pass through old_impl.sliding_window (i.e., use sliding_window=old_impl.sliding_window) or, if non-None values are unsupported, explicitly check old_impl.sliding_window and raise an informative error before creating ModelOptSparseAttentionImpl; reference ModelOptSparseAttentionImpl, old_impl, and old_impl.sliding_window to locate and fix the code.
🧹 Nitpick comments (1)
examples/vllm_serve/sparse_attn_worker.py (1)
Lines 111-119: Remove or wire `_match_sparse_config` to avoid dead-path drift.

`_match_sparse_config` is currently unused, which makes behavior harder to reason about and can drift from real matching logic. Either use it in the patching/selection flow or remove it until needed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/sparse_attn_worker.py` around lines 111 - 119, The helper function _match_sparse_config is unused and creates dead-path drift; either remove this function or wire it into the sparse patch/selection flow by replacing the current pattern-matching logic with a call to _match_sparse_config(module_name, sparse_cfg) (or call it from wherever sparse layer configs are looked up) so that matching behavior is centralized; update any callers that currently duplicate pattern checks to use _match_sparse_config and remove dead duplicates if you choose to keep it.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 81-83: The code currently returns whatever getattr(mtsa, cfg_name,
None) yields (cfg) and may hand back non-mapping objects; update the getter to
validate that cfg is a dict before returning (use isinstance(cfg, dict)) and
otherwise raise a clear error (e.g., ValueError) indicating that the symbol
named by cfg_name in modelopt.torch.sparsity.attention_sparsity must be a dict;
reference the getattr(mtsa, cfg_name, None) call and the cfg variable to locate
the change.
- Around line 92-106: The _load_sparse_config function currently trusts
arbitrary JSON from SPARSE_CALIB_CONFIG_PATH; update it to validate the loaded
object and each layer_cfg: assert the top-level JSON is a dict, allowed
top-level keys are strings and either "calibration" or pattern names, and each
layer_cfg is a dict before applying defaults; enforce allowed keys (e.g.,
"method", "backend", "enable", numeric sparsity params) and bounds for numeric
fields (e.g., sparsity percentages 0–100, integer layer indices >=0, and
reasonable max limits) and reject or clamp out-of-range values, raising a clear
exception on invalid schema; keep the existing defaults
(method="triton_sparse_softmax", backend="triton", enable=True) for valid
entries and ensure sparse_cfg["default"] = {"enable": False} remains set.
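A sketch of the schema validation described above; the allowed-key set and defaults follow the comment, while the numeric-bounds details are illustrative:

```python
import json

_ALLOWED_KEYS = {"method", "backend", "enable", "sparsity_n", "sparsity_m",
                 "num_sink_tokens", "dense_window_size", "skip_softmax_threshold"}
_DEFAULTS = {"method": "triton_sparse_softmax", "backend": "triton", "enable": True}

def load_sparse_config(path: str) -> dict:
    """Validate the calibration JSON before applying defaults."""
    with open(path) as f:
        raw = json.load(f)
    if not isinstance(raw, dict):
        raise ValueError(f"Top-level JSON must be an object, got {type(raw).__name__}")
    sparse_cfg: dict = {}
    for pattern, layer_cfg in raw.items():
        if pattern == "calibration":
            continue  # calibration metadata is handled separately
        if not isinstance(layer_cfg, dict):
            raise ValueError(f"Entry {pattern!r} must be an object, got {type(layer_cfg).__name__}")
        if unknown := set(layer_cfg) - _ALLOWED_KEYS:
            raise ValueError(f"Entry {pattern!r} has unknown keys: {sorted(unknown)}")
        for key in ("sparsity_n", "sparsity_m", "num_sink_tokens", "dense_window_size"):
            val = layer_cfg.get(key)
            if val is not None and (not isinstance(val, int) or val < 0):
                raise ValueError(f"{pattern!r}.{key} must be a non-negative integer, got {val!r}")
        sparse_cfg[pattern] = {**_DEFAULTS, **layer_cfg}
    sparse_cfg["default"] = {"enable": False}
    return sparse_cfg
```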
📒 Files selected for processing (1)
examples/vllm_serve/sparse_attn_worker.py
```python
cfg = getattr(mtsa, cfg_name, None)
if cfg is not None:
    return cfg
```
Validate preset object type before returning it
If SPARSE_ATTN_CFG matches a non-dict symbol in modelopt.torch.sparsity.attention_sparsity, cfg is returned as-is and later consumed as a mapping, which can crash at runtime. Add an explicit isinstance(cfg, dict) guard and fail fast with a clear error.
Proposed fix:

```diff
 cfg = getattr(mtsa, cfg_name, None)
 if cfg is not None:
-    return cfg
+    if not isinstance(cfg, dict):
+        raise ValueError(
+            f"Invalid sparse config preset '{cfg_name}': expected dict, got {type(cfg).__name__}."
+        )
+    return cfg
```
Force-pushed: 4644bf5 to 54079b8
cjluo-nv left a comment:
Summary: Adds paged KV cache support to the ModelOpt Triton flash attention kernel, a vLLM sparse attention plugin, worker classes for vLLM integration, and GPU tests for the paged KV path.
Issues Found:
- [Correctness] Backward pass silently broken for paged KV — in `triton_fa.py`, `forward()` saves `k, v` to `ctx` for backward, but in paged mode these are dummy/empty tensors (e.g., `k_dummy = torch.empty(0, ...)` from the vLLM plugin). If `.backward()` is ever called on paged-mode output, gradients for dK/dV will be computed against the dummy tensors, producing silently incorrect results. The backward should either raise `NotImplementedError("Backward not supported for paged KV cache")` when `is_paged=True`, or the limitation should be clearly documented in the `attention()` docstring. The current code just adds `None` return placeholders for the 4 new args without any guard.
- [Correctness] Unused `import_plugin` import in `plugins/__init__.py` — the diff adds `from modelopt.torch.utils import import_plugin` but it is never used in the file. The existing `__init__.py` doesn't use this import, and `vllm.py` imports directly from vLLM. This is dead code that should be removed, or if the intent was to use `import_plugin` for conditional vLLM import (as other plugin modules do), that wiring is missing.
- [Correctness] `_build_sparse_config` fallback logic is confusing — `sparse_attn_worker.py:78-86`: `getattr(mtsa, cfg_name, None)` is tried first, but then `SPARSE_SOFTMAX_DEFAULT` falls through to the hardcoded `_DEFAULT_SPARSE_CFG` dict. If `mtsa` actually defines `SPARSE_SOFTMAX_DEFAULT` in the future, the `getattr` path would return it and the hardcoded default would never be used, leading to silent behavior divergence. Consider making the precedence explicit or removing the duplication.
- [Readability] Duplicated paged V-tile loading — in the `triton_fa.py` `_attn_fwd` kernel, `_load_paged_v_tile` is called with identical arguments in two separate branches (skip-softmax path around line 476 and standard path around line 518). While Triton JIT constraints may require this, the duplicated 20-line call blocks are a readability concern. A comment explaining why the duplication is necessary would help.
- [Tests] No backward/gradient test for paged mode — `test_triton_fa_paged.py` only tests forward correctness. Given that paged mode changes the autograd Function's `forward` signature and backward is not updated to support paged KV, there should be at minimum a test asserting that backward raises an error (if a guard is added per issue #1), or this should be explicitly documented as inference-only.
- [Tests] No integration test for vLLM plugin — `ModelOptSparseAttentionImpl.forward()` and `ModelOptSparseAttentionBackend` have no test coverage. These are the most integration-critical new classes. Even a mock-based unit test validating the metadata translation logic (`cu_seqlens_q` → `b_start_loc`, `seq_lens` → `b_seq_len_k`) would catch regressions (a sketch of such a test follows this list).
- [Tests] Inconsistent `b_start_loc_k` handling across tests — `test_paged_matches_contiguous` passes explicit `b_start_loc_k=locs_k`, while `test_paged_no_nan` omits it (relying on the dummy-zeros fallback). Both should use the same calling convention to avoid masking bugs in the fallback path.
- [Readability] `if threshold:` truthiness check — `sparse_attn_worker.py:146`: `if threshold:` evaluates `False` for both `None` and `0.0`. While `0.0` correctly means "disabled", this is subtle. `if threshold is not None` would be clearer about intent (let the kernel handle the `0.0` case).
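A minimal shape of the mock-based test suggested in the integration-test item above, kept self-contained with a stand-in for the plugin's translation logic; the real test would patch the kernel call inside `plugins/vllm.py` and drive `ModelOptSparseAttentionImpl.forward` instead:

```python
import torch
from unittest import mock

def translate_metadata(attn_metadata):
    """Stand-in for the plugin's vLLM-metadata -> kernel-argument translation."""
    return {
        "b_start_loc": attn_metadata.query_start_loc,
        "b_seq_len_k": attn_metadata.seq_lens,
    }

def test_metadata_translation():
    meta = mock.Mock()
    meta.query_start_loc = torch.tensor([0, 3, 5], dtype=torch.int32)
    meta.seq_lens = torch.tensor([7, 9], dtype=torch.int32)
    kwargs = translate_metadata(meta)
    assert torch.equal(kwargs["b_start_loc"], meta.query_start_loc)
    assert torch.equal(kwargs["b_seq_len_k"], meta.seq_lens)
```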
Suggestions:
- Consider adding `requires_grad=False` to `k_dummy` and `v_dummy` in the vLLM plugin to make the inference-only intent explicit and catch accidental backward calls early.
- The `_DEFAULT_SPARSE_CFG` hardcoded in the worker could reference `mtsa` constants instead, reducing drift risk.
- `test_paged_decode` uses `q_flat` with shape `[batch, num_heads, head_dim]` (3D) rather than the expected `[total_q_tokens, num_heads, head_dim]`. This works because `batch * 1 = batch` tokens, but the shape semantics are confusing — using `q_flat.reshape(batch, num_heads, head_dim)` explicitly or adding a comment would clarify.
Overall Assessment: The core kernel extension (paged KV tile loaders + IS_PAGED branching) is well-structured and the tests verify forward correctness across multiple configurations. The main blocking concern is the silent backward-pass breakage for paged mode — this needs at minimum a guard or documentation since the function is part of the public attention() API with autograd support. The unused import is a minor cleanup. The vLLM plugin lacks test coverage but is in examples/ territory so is less critical.
Implements a ModelOpt sparse attention plugin for diffusers WAN models, building on the triton_fa kernel infrastructure from PR #1127.

New files:
- modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
  - ModelOptWanAttnProcessor: replaces WanAttnProcessor, calls triton_fa.attention() directly with BSND->varlen conversion. Supports I2V cross-attention path and N:M/skip-softmax sparsity.
  - WanSparseAttentionModule: subclasses SparseAttentionModule, installs the Triton processor and syncs enabled state on each forward.
  - register_wan_sparse_attention(): plugin callback auto-registered in CUSTOM_MODEL_PLUGINS; fires during mtsa.sparsify().

Updated files:
- plugins/__init__.py: lazy-import diffusers plugin via import_plugin()
- config.py: add "diffusers_triton" to validate_backend whitelist
- conversion.py: skip HF attn registration for "diffusers_triton" backend
- wan2_sage_attention.py: add triton-sparse and triton-skip kernel options backed by mtsa.sparsify() with diffusers_triton backend

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
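The BSND->varlen conversion mentioned above, sketched in isolation (the actual processor in diffusers.py may differ; this only shows the reshape plus the cumulative-offset bookkeeping a varlen kernel expects):

```python
import torch

def bsnd_to_varlen(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Flatten [B, S, N, D] into varlen [B*S, N, D] plus per-sequence offsets."""
    b, s, n, d = x.shape
    x_flat = x.reshape(b * s, n, d)
    # Offsets [0, s, 2s, ..., b*s]: every sequence in the batch has length s.
    cu_seqlens = torch.arange(0, (b + 1) * s, s, device=x.device, dtype=torch.int32)
    return x_flat, cu_seqlens, s  # tokens, b_start_loc-style offsets, max seqlen
```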
Force-pushed: 54079b8 to bbac896
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/vllm_serve/vllm_serve_sparse_attn.py`:
- Around line 81-87: The current conditional treats a quant-only setup as the
"none set" case, causing an incorrect warning; update the branch logic around
has_quant, has_sparse and worker_cls so quant-only is handled explicitly: add an
elif has_quant branch that sets worker_cls to a suitable quant-only class (e.g.,
"sparse_attn_worker.QuantWorker") and only use the warning in the final else
when neither has_sparse nor has_quant is true, or change the warning message to
accurately reflect the detected state; refer to the boolean flags has_quant and
has_sparse and the variable worker_cls to locate and modify the logic.
- Around line 31-36: Move the hard imports of optional packages (uvloop, vllm
and the vllm.entrypoints imports run_server and make_arg_parser) out of module
scope and into the main() function (or a helper invoked by main), using the
project's import_plugin() helper or try/except ImportError to lazily load them
and provide a clear error message if missing; update any top-level references to
run_server/make_arg_parser to use the locally imported names inside main() so
the module can be imported even when those extras are not installed.
- Around line 72-76: The PYTHONPATH construction currently appends repo_root
using a hardcoded ":" which is non-portable and creates a leading separator when
PYTHONPATH is empty; update the logic that sets os.environ["PYTHONPATH"] to use
os.pathsep and only prepend the separator when an existing value is non-empty
(i.e., read current = os.environ.get("PYTHONPATH", "") and set it to current +
os.pathsep + repo_root if current else repo_root) so that repo_root, sys.path
insertion, and environment modification remain correct and platform-independent.
📒 Files selected for processing (2)
- examples/vllm_serve/sparse_attn_worker.py
- examples/vllm_serve/vllm_serve_sparse_attn.py
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/vllm_serve/sparse_attn_worker.py
```python
import uvloop
import vllm
from packaging import version
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
```
Move vLLM and uvloop imports into main() function to handle missing optional dependencies gracefully.
Lines 31–35 perform hard module-level imports of uvloop, vllm, and related entrypoints. This breaks the module import in environments without these optional packages installed. Relocate these imports inside main() with appropriate error handling, or use import_plugin() as established throughout the codebase for optional integrations.
As per coding guidelines: "**/*.py: Use optional dependencies gated by install extras; avoid hard imports at module level for optional features" and "Load optional integrations lazily via import_plugin()."
meenchen left a comment:
Can you also update the README of the example for sparse attention?
Can this worker be merged with fakequant_worker? Ideally, we would like a unified entry point for both quantization and sparsity, so we can simulate quantization and sparsity at the same time.
SparseQuantWorker in sparse_attn_worker.py already supports this. Currently, we have three workers:
- `FakeQuantWorker` in `fakequant_worker.py` (quantization only)
- `SparseAttnWorker` in `sparse_attn_worker.py` (sparsity only)
- `SparseQuantWorker` in `sparse_attn_worker.py` (quantization + sparsity) — this is already the unified implementation
We can consolidate these three workers into a single unified worker, such as ModelOptWorker, in a follow-up PR.
cjluo-nv left a comment:
Bot review — DM the bot to share feedback.
Re-review. Several critical items from prior rounds are still unresolved:
Unresolved critical:
1. Backward pass silently incorrect in paged mode (flagged by cjluo-nv and CodeRabbit). `_Attention.forward()` does `ctx.save_for_backward(q, k, v, ...)` and the backward kernels index `K`/`V` contiguously using `b_start_loc_k`. In paged mode, callers (plugins/vllm.py) pass `k_dummy = torch.empty(0, num_kv_heads, head_size, ...)` and `b_start_loc_k=None` (then zero-filled). If anything ever calls `.backward()` on paged output, backward will either OOB-load or compute garbage gradients silently. There is still no guard. At minimum add `if is_paged and (q.requires_grad or k.requires_grad or v.requires_grad): raise NotImplementedError(...)` in `forward`, and document this limitation in the `attention()` docstring.
2. No tests for the new vLLM plugin (`ModelOptSparseAttentionImpl.forward`, `ModelOptSparseAttentionBackend`). The paged kernel has good tests, but the integration-critical layer (metadata translation `query_start_loc` → `b_start_loc`, `seq_lens` → `b_seq_len_k`, KV cache unbind / page_size derivation) is untested. Given the "workers disagree on kv_cache axis order" bug found last round, a small mock-based unit test here is important.
3. README not updated as meenchen requested.
Unresolved minor / cleanup:
4. modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py adds from modelopt.torch.utils import import_plugin but never uses it. Either wire vllm.py in via import_plugin(".vllm", ...) (matching other plugin modules) or remove the import. Currently plugins/vllm.py is only reachable from the example.
5. vllm_serve_sparse_attn.py line 75: os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}" is not portable and creates a leading : when PYTHONPATH is unset. Use os.pathsep and conditional join.
6. vllm_serve_sparse_attn.py warning branch (line 85–86) prints "No SPARSE_ATTN_CFG or QUANT_CFG set" even when QUANT_CFG is set but sparse is not; the message is inaccurate for the quant-only case.
7. sparse_attn_worker.py _build_sparse_config has a dead branch: SPARSE_SOFTMAX_DEFAULT is exported from mtsa.config, so getattr(mtsa, "SPARSE_SOFTMAX_DEFAULT", None) already returns the canonical preset and the local _DEFAULT_SPARSE_CFG fallback can never be reached. Either remove the local fallback or explain why it's intentionally divergent.
8. sparse_attn_worker.py line 151: still passes sliding_window=None to ModelOptSparseAttentionImpl.__init__ and then fixes up with new_impl.sliding_window = old_impl.sliding_window afterward. That relies on FlashAttentionImpl.__init__ not using the value for anything else (e.g., internal config). Safer to pass old_impl.sliding_window directly through the constructor, or assert old_impl.sliding_window is unset.
Design review (required per complexity gate): this is PR 4/n of a subsystem landed incrementally, and the kernel extension + integration style (impl swap rather than a new vLLM attention backend registered via entry points) is a reasonable reuse of FlashAttentionImpl. Good: ModelOptSparseAttentionImpl inherits to reuse __init__ and do_kv_cache_update, only overriding forward. That keeps the scope focused. No second system duplicating existing in-repo sparse-attention infra is introduced.
Please address #1 (backward guard) and #3 (README) before merging; #2 (plugin tests) is strongly recommended given the axis-order bug caught last round.
```python
skip_softmax_raw_threshold,
measure_sparsity,
k_cache,
v_cache,
```
Bot comment.
Paged mode is forward-only, but there is no guard: ctx.save_for_backward(q, k, v, ...) stores whatever the caller passed (dummy torch.empty(0, num_kv_heads, head_size) from plugins/vllm.py), and backward then indexes K/V contiguously through b_start_loc_k. A stray .backward() on paged output will OOB-load or produce silently wrong gradients.
Suggest adding right after is_paged = k_cache is not None:
```python
if is_paged and (q.requires_grad or k.requires_grad or v.requires_grad):
    raise NotImplementedError(
        "Paged KV cache path is forward-only; backward is not implemented."
    )
```

and document the limitation in the `attention()` docstring.
```python
"""Plugins for sparse attention integration with various frameworks."""

from modelopt.torch.utils import import_plugin
```
Bot comment.
import_plugin is imported here but never used (and vllm.py is not wired in via the usual import_plugin(".vllm", ...) pattern). Either register vllm.py through this mechanism (matching other plugin modules like huggingface) or drop this dead import.
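The wiring described above would look roughly like this, assuming `import_plugin` is the same context manager the huggingface plugin module uses to swallow missing-dependency ImportErrors:

```python
from modelopt.torch.utils import import_plugin

with import_plugin("vllm"):
    from .vllm import *  # only registered when vLLM is installed
```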
```python
sparsity_n = layer_cfg.get("sparsity_n", 0)
if sparsity_n > 0:
    sparse_kw["sparsity_n"] = sparsity_n
    sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4)
```
Bot comment.
Still passing sliding_window=None into __init__ and patching new_impl.sliding_window afterward. That only works if FlashAttentionImpl.__init__ doesn't use the value internally (e.g. for backend selection or capability checks). Safer to pass old_impl.sliding_window directly — and if the comment about "can't reverse it" is true, at least assert the old impl's value is either None or the raw tuple so future vLLM versions don't silently break this workaround.
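A sketch of the suggested constructor pass-through plus guard; whether `FlashAttentionImpl` stores the argument unmodified is an assumption to verify:

```python
# Guard against vLLM changing its internal sliding_window representation.
assert old_impl.sliding_window is None or isinstance(old_impl.sliding_window, tuple), (
    f"Unexpected sliding_window value: {old_impl.sliding_window!r}"
)
module.impl = ModelOptSparseAttentionImpl(
    num_heads=old_impl.num_heads,
    head_size=old_impl.head_size,
    scale=old_impl.scale,
    num_kv_heads=old_impl.num_kv_heads,
    alibi_slopes=old_impl.alibi_slopes,
    sliding_window=old_impl.sliding_window,  # pass through, no post-hoc patching
    kv_cache_dtype=old_impl.kv_cache_dtype,
    logits_soft_cap=old_impl.logits_soft_cap,
    attn_type=old_impl.attn_type,
    kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name,
)
```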
```python
if cfg is not None:
    return cfg
# Use built-in default if name matches
if cfg_name in ("SPARSE_SOFTMAX_DEFAULT", "default"):
```
Bot comment.
Dead fallback: SPARSE_SOFTMAX_DEFAULT is exported from modelopt.torch.sparsity.attention_sparsity.config (and re-exported from mtsa), so getattr(mtsa, 'SPARSE_SOFTMAX_DEFAULT', None) will always return the canonical preset — _DEFAULT_SPARSE_CFG is unreachable for that name. Either delete _DEFAULT_SPARSE_CFG or document why it intentionally differs from mtsa.SPARSE_SOFTMAX_DEFAULT (e.g. dense_window_size: 1 vs 64).
```python
repo_root = str(Path(__file__).resolve().parent)
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}"
```
Bot comment.
Non-portable (: is POSIX-only) and prepends an empty path component when PYTHONPATH is unset. Use:
```python
current = os.environ.get("PYTHONPATH")
os.environ["PYTHONPATH"] = os.pathsep.join([current, repo_root]) if current else repo_root
```

```python
elif has_sparse:
    worker_cls = "sparse_attn_worker.SparseAttnWorker"
else:
    print("Warning: No SPARSE_ATTN_CFG or QUANT_CFG set. Running standard vLLM.")
```
Bot comment.
This branch is hit when QUANT_CFG/KV_QUANT_CFG is set but sparse env vars are not, and the message "No SPARSE_ATTN_CFG or QUANT_CFG set" is then false. Either route quant-only to fakequant_worker.FakeQuantWorker here, or make the message say "No SPARSE_ATTN_CFG set — running standard vLLM (quant-only is not handled by this launcher)".
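The quant-only routing fix, sketched; the exact env vars each flag checks are assumptions based on the names in this thread:

```python
import os

has_sparse = bool(os.environ.get("SPARSE_ATTN_CFG") or os.environ.get("SPARSE_CALIB_CONFIG_PATH"))
has_quant = bool(os.environ.get("QUANT_CFG") or os.environ.get("KV_QUANT_CFG"))

if has_sparse and has_quant:
    worker_cls = "sparse_attn_worker.SparseQuantWorker"
elif has_sparse:
    worker_cls = "sparse_attn_worker.SparseAttnWorker"
elif has_quant:
    worker_cls = "fakequant_worker.FakeQuantWorker"  # route quant-only explicitly
else:
    worker_cls = None
    print("Warning: No SPARSE_ATTN_CFG or QUANT_CFG set. Running standard vLLM.")
```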
🧹 Nitpick comments (1)
examples/vllm_serve/vllm_serve_sparse_attn.py (1)
Lines 35-45: ⚡ Quick win: Consider moving optional imports into `main()` for graceful error handling.

Lines 35–39 perform hard imports of `uvloop`, `vllm`, and vLLM entrypoints at module level. If these optional packages are not installed, importing this module will fail immediately. Moving these imports into `main()` with a try/except block would allow for clearer error messages and better align with best practices for optional dependencies.

♻️ Proposed refactor:

```diff
-import uvloop
-import vllm
-from packaging import version
-from vllm.entrypoints.openai.api_server import run_server
-from vllm.entrypoints.openai.cli_args import make_arg_parser
-
-vllm_version = version.parse(vllm.__version__)
-if vllm_version <= version.parse("0.11.0"):
-    from vllm.utils import FlexibleArgumentParser
-else:
-    from vllm.utils.argparse_utils import FlexibleArgumentParser
-
-
 def main():
     """Launch vLLM with sparse attention worker."""
+    try:
+        import uvloop
+        import vllm
+        from packaging import version
+        from vllm.entrypoints.openai.api_server import run_server
+        from vllm.entrypoints.openai.cli_args import make_arg_parser
+    except ImportError as e:
+        print(f"Error: {e}")
+        print("Please install vLLM and uvloop: pip install vllm uvloop")
+        sys.exit(1)
+
+    vllm_version = version.parse(vllm.__version__)
+    if vllm_version <= version.parse("0.11.0"):
+        from vllm.utils import FlexibleArgumentParser
+    else:
+        from vllm.utils.argparse_utils import FlexibleArgumentParser
+
     parser = FlexibleArgumentParser(description="vLLM model server with sparse attention")
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/vllm_serve/vllm_serve_sparse_attn.py` around lines 35 - 45, The module-level imports of optional dependencies (uvloop, vllm, run_server, make_arg_parser, and FlexibleArgumentParser) should be moved into the main entrypoint (e.g., main()) and wrapped in a try/except so missing packages raise a clear, user-friendly error instead of failing on import; update the code to import uvloop, vllm, vllm.entrypoints.openai.api_server.run_server, vllm.entrypoints.openai.cli_args.make_arg_parser, and the FlexibleArgumentParser variant inside main(), detect vllm version as before, and on ImportError/ModuleNotFoundError print a concise message indicating which package is missing and how to install it, then exit gracefully.
📒 Files selected for processing (13)
- examples/vllm_serve/README.md
- examples/vllm_serve/sparse_attn_worker.py
- examples/vllm_serve/vllm_serve_sparse_attn.py
- modelopt/torch/kernels/common/attention/triton_fa.py
- modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
- modelopt/torch/sparsity/attention_sparsity/conversion.py
- modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py
- modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
- tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py
✅ Files skipped from review due to trivial changes (2)
- modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
- examples/vllm_serve/README.md
Signed-off-by: Kai Xu <kaix@nvidia.com>
Force-pushed: 252d8fb to 20034a2
What does this PR do?
Type of change: New feature. Add vLLM integration for ModelOpt sparse attention with paged KV cache support.
- Extends the Triton flash attention kernel (`triton_fa.py`) with paged KV cache support. KV cache can be read directly from vLLM's non-contiguous paged cache via `block_table` lookup, avoiding expensive gather-to-contiguous copies.
- `SparseVLLMAttention` wraps vLLM's Attention layer. It lets vLLM write KV to its paged cache, then calls the ModelOpt Triton kernel with `k_cache`, `v_cache`, `block_table` for both prefill and decode.
- `SparseAttnWorker` patches vLLM attention modules at model load time. `SparseQuantWorker` combines quantization + sparse attention. Worker selection is automatic based on env vars (`SPARSE_ATTN_CFG`, `QUANT_CFG`).
- `vllm_serve_sparse_attn.py` launches a vLLM OpenAI-compatible server with sparse attention enabled.

Usage
Testing
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
Release Notes
New Features
Documentation
Chores