Add MLA support to vLLM patch#1499
Conversation
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
📝 Walkthrough

This pull request adds comprehensive sparse-attention support for vLLM by extending Triton flash-attention kernels with paged KV-cache and per-sequence skip-softmax thresholds, implementing a vLLM sparse-attention plugin, providing worker classes for model patching, and delivering a server launcher with extensive test coverage and configuration updates.

Changes: vLLM Sparse Attention Infrastructure
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 6 passed
Actionable comments posted: 14
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/kernels/common/attention/triton_fa.py (2)
894-953: ⚠️ Potential issue | 🔴 Critical | 🏗️ Heavy lift

Resolve merge conflict in skip-softmax threshold computation.
Two different threshold configuration models are present:

- HEAD: `skip_softmax_raw_threshold` (raw kernel value) vs `skip_softmax_threshold` (lambda)
- Incoming: mutually exclusive fixed mode vs scale-factor mode with prefill/decode phase detection
Decide on the threshold configuration API and resolve accordingly.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/kernels/common/attention/triton_fa.py` around lines 894 - 953, Resolve the merge conflict by unifying both configurations: support the raw kernel override (skip_softmax_raw_threshold) while keeping the mutually-exclusive fixed vs scale-factor API (skip_softmax_threshold and skip_softmax_threshold_scale_prefill/decode). Implement logic that first checks skip_softmax_raw_threshold (if not None) and sets apply_skip=True and skip_threshold_log2=skip_softmax_raw_threshold and skip_threshold_scale_log2=0.0; otherwise determine fixed_mode, scale_prefill_set, scale_decode_set and assert mutual exclusivity, compute is_decode from max_input_len, set active_scale and use_scale_factor, then set apply_skip, skip_threshold_log2 (math.log2(skip_softmax_threshold)*sm_scale for fixed_mode or 0.0 for scale-factor), and skip_threshold_scale_log2 (0.0 for fixed_mode or math.log2(active_scale)*sm_scale for scale-factor), with the fallback of apply_skip=False and both skip_threshold_* set to 0.0.
370-406: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Resolve merge conflict in skip-softmax implementation.
Two different implementations of skip-softmax tile-skipping logic are present:

- HEAD: delegates to the `_skip_softmax_decision` helper with sparsity counters
- Incoming: inline BLASST-style tile skipping with scale-factor support
Choose one implementation or merge both features if both sparsity measurement and scale-factor mode are needed.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/kernels/common/attention/triton_fa.py` around lines 370 - 406, There is a merge conflict between two skip-softmax implementations: one that calls the helper _skip_softmax_decision(scores, row_max, SKIP_THRESHOLD_LOG2, Sparsity_total, Sparsity_skipped, MEASURE_SPARSITY) and an inline BLASST-style implementation using tile_row_max, can_skip, skip_threshold_log2 and USE_SKIP_SCALE_FACTOR; pick one coherent approach or merge them by keeping the BLASST inline logic (tile_row_max, can_skip, skip_tile) and then update the sparsity counters (Sparsity_total, Sparsity_skipped, MEASURE_SPARSITY) in the same block so metrics remain collected, or replace the inline block with a single call to _skip_softmax_decision and ensure the threshold symbol is consistent (SKIP_THRESHOLD_LOG2 vs skip_threshold_log2) and USE_SKIP_SCALE_FACTOR behavior is respected; update only the code paths referenced by _skip_softmax_decision or the inline variables (scores, row_max, tile_row_max, skip_tile, SKIP_THRESHOLD_LOG2/skip_threshold_log2, Sparsity_total, Sparsity_skipped, MEASURE_SPARSITY, USE_SKIP_SCALE_FACTOR) so the final file contains one unambiguous implementation and consistent variable names.
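For intuition about what the inline BLASST-style branch computes, the skip decision can be modeled outside Triton with NumPy. This is a sketch under stated assumptions, not the kernel: `skip_tile_decision` is a hypothetical name, and the real kernel works on log2-scaled scores per tile inside the flash-attention loop.

```python
import numpy as np


def skip_tile_decision(scores, row_max, skip_threshold_log2):
    """Skip a KV tile whose best score is so far below the running per-row
    maximum (log2 domain) that all of its softmax weights round to ~0."""
    tile_row_max = scores.max(axis=1)                 # best score in tile, per query row
    can_skip = tile_row_max - row_max < skip_threshold_log2
    return bool(can_skip.all())                       # skip only if every row agrees


# Toy example: a tile dominated by earlier context can be skipped.
rng = np.random.default_rng(0)
row_max = np.full(4, 10.0)                            # running maxima from prior tiles
weak_tile = rng.uniform(-5.0, -1.0, size=(4, 8))      # all scores at least 11 below row_max
print(skip_tile_decision(weak_tile, row_max, skip_threshold_log2=-10.0))  # → True
```

Whichever variant is kept, the sparsity counters would simply tally how often this decision returns True versus the total number of tiles visited.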
🧹 Nitpick comments (2)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py (1)
290-337: ⚡ Quick win

Strengthen decode test with correctness parity (not only finiteness).
Current decode coverage checks shape/NaN, but not correctness against a contiguous reference, so silent numerical regressions can slip through. Add a contiguous decode baseline and assert closeness; also add an Inf check for consistency with other tests.
Suggested patch
```diff
@@ def test_paged_decode(self): @@
-        out = attention(
+        out_contig = attention(
+            q_flat,
+            k_flat,
+            v_flat,
+            b_start_loc_q,
+            b_seq_len_q,
+            1,
+            is_causal=False,
+            softmax_scale=scale,
+            b_start_loc_k=b_start_loc_k,
+            b_seq_len_k=b_seq_len_k,
+            max_input_len_k=max(seq_lens_k),
+        )
+
+        out = attention(
             q_flat,
             k_flat,
             v_flat,
@@
         )
+        torch.testing.assert_close(out, out_contig, rtol=1e-2, atol=1e-2)
         assert out.shape == q_flat.shape
         assert not torch.isnan(out).any(), "NaN in paged decode output"
+        assert not torch.isinf(out).any(), "Inf in paged decode output"
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py` around lines 290 - 337, test_paged_decode only checks shape/NaN but not numeric parity with a contiguous decode; call the same attention(...) without the paged cache (pass k_flat/v_flat and omit k_cache/v_cache/block_table/page_size, using the same q_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k and softmax_scale) to produce a baseline output, then assert the paged output from attention(...) is close to the baseline (use torch.allclose or torch.isclose with reasonable rtol/atol) and also assert torch.isfinite on both outputs; update the test_paged_decode function (references: test_paged_decode, attention, _scatter_to_paged_cache, q_flat, k_flat, v_flat, k_cache, v_cache, block_table) to include these comparisons.

modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py (1)
107-115: 💤 Low value

Consider consolidating logging output.
The message is emitted via both `logger.warning()` and `print()`. Since the logger should handle output routing, the explicit `print()` may be redundant unless there's a specific need to ensure console visibility independent of logging configuration.

♻️ Simplify by removing print

```diff
     logger.warning(msg)
-    print(msg, flush=True)
     _LOGGED_CONFIGS.add(kw_key)
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py` around lines 107 - 115, The code logs the sparse Triton kernel message twice by calling logger.warning(msg) and print(msg); remove the redundant print() call so all output is routed through the configured logger (leave logger.warning(msg) and the subsequent _LOGGED_CONFIGS.add(kw_key) intact), referencing the same variables used in the message construction (variant_str, is_prefill, layer, sparse_kw, kw_key) to locate the spot in vllm.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 148-152: The check that sets sparse_kw["skip_softmax_threshold"]
incorrectly treats 0.0 as falsy; change the conditional that currently uses "if
threshold:" to explicitly test for non-None (e.g., "if threshold is not None:")
so that a valid zero threshold from layer_cfg or env_threshold is preserved;
update the block that reads threshold from layer_cfg and env_threshold
(variables: threshold, layer_cfg, env_threshold) and assigns to sparse_kw to use
this non-None check.
- Around line 107-108: Replace the bare file open used to load calib_cfg with a
robust read that specifies encoding and handles errors: change the block using
"with open(path) as f: calib_cfg = json.load(f)" to "with open(path, 'r',
encoding='utf-8') as f: ..." and wrap it in a try/except that catches
FileNotFoundError, json.JSONDecodeError, and generic OSError to either log a
clear error (including the path and exception) or re-raise a descriptive
exception; ensure the variable name calib_cfg and the path variable remain the
same so callers are unaffected.
In `@examples/vllm_serve/vllm_serve_sparse_attn.py`:
- Around line 78-81: The PYTHONPATH concatenation currently always prepends ":"
which yields a leading empty path entry when PYTHONPATH is empty; update the
logic around repo_root/sys.path.insert and os.environ["PYTHONPATH"] to read the
existing value (os.environ.get("PYTHONPATH")), and set os.environ["PYTHONPATH"]
to either "<existing> + os.pathsep + repo_root" if existing is non-empty or just
"repo_root" if not, using os.pathsep instead of a hardcoded ":" to be portable
and avoid introducing an empty path entry.
In `@modelopt/torch/kernels/__init__.py`:
- Around line 16-55: Remove the leftover merge conflict markers and keep the
completed initialization logic from the incoming change: use the descriptive
module docstring, import torch and import_plugin, define IS_AVAILABLE=False and
attention/attention_with_lse/register_triton_attention=None, then if
torch.cuda.is_available() use import_plugin("triton") to import .triton_fa
attention and attention_with_lse, set IS_AVAILABLE=True, and inside an
import_plugin("transformers") block import and assign register_triton_attention
from .hf_triton_attention; finally ensure __all__ contains "IS_AVAILABLE",
"attention", "attention_with_lse", and "register_triton_attention".
In `@modelopt/torch/kernels/common/attention/triton_fa.py`:
- Around line 1323-1325: There is a merge conflict in the module exports:
reconcile the __all__ to export the correct public API by including the
functions consumers need—ensure LOG2E and _apply_mask are exported only if they
are intended public symbols, and add the new API attention_with_lse if the
incoming branch introduced it; update the __all__ to something like ["LOG2E",
"_apply_mask", "attention", "attention_with_lse"] (omit LOG2E/_apply_mask if
they are internal), and keep the symbol names exactly as defined in this file
(LOG2E, _apply_mask, attention, attention_with_lse) so imports work correctly.
- Around line 1309-1320: The call to _Attention.apply has leftover
merge-conflict argument names; update the argument list to match the resolved
forward signature by removing skip_softmax_raw_threshold and measure_sparsity
and passing skip_softmax_threshold_scale_prefill and
skip_softmax_threshold_scale_decode in their place (preserving the order used by
_Attention.apply), and ensure other arguments like k_cache, v_cache,
block_table, and page_size remain unchanged so the call matches the
implementation signature.
- Around line 1005-1012: The merge left conflicting argument names in the
_attn_fwd kernel call: remove the leftover
Sparsity_total/Sparsity_skipped/MEASURE_SPARSITY branch and replace it with the
resolved kernel parameters USE_SKIP_SCALE_FACTOR and SKIP_THRESHOLD_SCALE_LOG2
(passing the local variables use_scale_factor and skip_threshold_scale_log2
respectively) so the _attn_fwd invocation matches the final kernel signature;
ensure no duplicate or undefined symbols remain and that the argument
order/keywords align with the kernel's definition.
- Around line 1182-1193: The backward gradient tuple in triton_fa.py must match
the resolved forward parameter list: replace the two outdated placeholders
(skip_softmax_raw_threshold, measure_sparsity) with the correct entries for
skip_softmax_threshold_scale_prefill and skip_softmax_threshold_scale_decode and
ensure the ordering matches the forward signature before the existing k_cache,
v_cache, block_table, page_size placeholders so the number and order of None
entries align with the forward method.
- Around line 238-247: The _attn_fwd kernel signature contains git conflict
markers and missing parameters; remove the conflict markers and merge both
branches so the signature includes SKIP_THRESHOLD_LOG2: tl.constexpr,
USE_SKIP_SCALE_FACTOR: tl.constexpr, SKIP_THRESHOLD_SCALE_LOG2: tl.constexpr,
plus the runtime sparsity measurement params Sparsity_total, Sparsity_skipped
and MEASURE_SPARSITY (keeping the existing types/comments), and update the
inline comments so SKIP_THRESHOLD_LOG2 is documented as the fixed-mode
pre-scaled log2(lambda) and SKIP_THRESHOLD_SCALE_LOG2/USE_SKIP_SCALE_FACTOR are
documented for scale-factor mode; ensure no duplicate names and that the
parameter order matches callers of _attn_fwd.
- Around line 856-862: The _Attention.forward signature contains leftover
merge-conflict parameters; update the function signature to match the resolved
kernel signature (the parameters shown around lines 238-247 in the diff) by
removing the obsolete names (e.g., skip_softmax_raw_threshold, measure_sparsity)
and replacing them with the finalized parameters (e.g.,
skip_softmax_threshold_scale_prefill, skip_softmax_threshold_scale_decode) so
the forward method parameters exactly mirror the resolved triton kernel
signature.
- Around line 1214-1220: The public attention() signature contains unresolved
merge markers: remove the conflict markers and settle on a single API; choose
whether to keep skip_softmax_raw_threshold or the new pair
skip_softmax_threshold_scale_prefill and skip_softmax_threshold_scale_decode (do
not leave both variants). For minimal breakage prefer the new explicit pair plus
retain measure_sparsity if it was intended: update the attention(...) parameter
list to include skip_softmax_threshold_scale_prefill: float | None = None,
skip_softmax_threshold_scale_decode: float | None = None, measure_sparsity: bool
= False (and remove skip_softmax_raw_threshold and all <<<<>>> markers), then
update the attention() docstring and any exports/tests that reference the old
name to match the chosen signature.
- Around line 1409-1454: The _attn_fwd kernel call is missing the sm_scale
argument expected immediately after qk_scale; update the invocation of _attn_fwd
to pass the sm_scale parameter (the same value/type used elsewhere where the
kernel is called or from the local variable named sm_scale) between qk_scale and
b_start_loc so the argument order matches the kernel signature used by _attn_fwd
and avoids parameter misalignment.
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 728-732: The __all__ export list contains unresolved merge markers
and is currently choosing between VSA_DEFAULT and SPARSE_SOFTMAX_SKIP_DEFAULT;
remove the conflict markers and export the correct names — include both
VSA_DEFAULT and SPARSE_SOFTMAX_SKIP_DEFAULT in the __all__ list (no merge
markers, no duplicates) so both symbols are publicly exported from the module.
In `@pyproject.toml`:
- Around line 52-53: Remove "uvloop>=0.22.1" from the core pyproject.toml
dependencies and add it to the examples/vllm_serve/requirements.txt since uvloop
is only used by the example scripts (refer to examples/vllm_serve/). For
"vllm==0.20.1", move it out of core dependencies into
[project.optional-dependencies] (or an extras group) and document that it is
required only for the vLLM quantization plugin invoked via
import_plugin("vllm"); also relax the pin to a range such as ~=0.20.1 or
>=0.20.1,<0.21 to avoid a fragile exact-match dependency.
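A hypothetical sketch of the suggested pyproject.toml split follows; the section contents are illustrative and the repository's actual extras layout may differ:

```toml
[project]
dependencies = [
    # uvloop and vllm removed from core dependencies;
    # uvloop moves to examples/vllm_serve/requirements.txt
]

[project.optional-dependencies]
vllm = [
    # Needed only for the vLLM plugin loaded via import_plugin("vllm");
    # relaxed pin instead of a fragile exact match
    "vllm~=0.20.1",
]
```

Users of the plugin would then install with the extra (e.g. `pip install nvidia-modelopt[vllm]`, assuming that package name), while core installs stay lean.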
---
Outside diff comments:
In `@modelopt/torch/kernels/common/attention/triton_fa.py`:
- Around line 894-953: Resolve the merge conflict by unifying both
configurations: support the raw kernel override (skip_softmax_raw_threshold)
while keeping the mutually-exclusive fixed vs scale-factor API
(skip_softmax_threshold and skip_softmax_threshold_scale_prefill/decode).
Implement logic that first checks skip_softmax_raw_threshold (if not None) and
sets apply_skip=True and skip_threshold_log2=skip_softmax_raw_threshold and
skip_threshold_scale_log2=0.0; otherwise determine fixed_mode,
scale_prefill_set, scale_decode_set and assert mutual exclusivity, compute
is_decode from max_input_len, set active_scale and use_scale_factor, then set
apply_skip, skip_threshold_log2 (math.log2(skip_softmax_threshold)*sm_scale for
fixed_mode or 0.0 for scale-factor), and skip_threshold_scale_log2 (0.0 for
fixed_mode or math.log2(active_scale)*sm_scale for scale-factor), with the
fallback of apply_skip=False and both skip_threshold_* set to 0.0.
- Around line 370-406: There is a merge conflict between two skip-softmax
implementations: one that calls the helper _skip_softmax_decision(scores,
row_max, SKIP_THRESHOLD_LOG2, Sparsity_total, Sparsity_skipped,
MEASURE_SPARSITY) and an inline BLASST-style implementation using tile_row_max,
can_skip, skip_threshold_log2 and USE_SKIP_SCALE_FACTOR; pick one coherent
approach or merge them by keeping the BLASST inline logic (tile_row_max,
can_skip, skip_tile) and then update the sparsity counters (Sparsity_total,
Sparsity_skipped, MEASURE_SPARSITY) in the same block so metrics remain
collected, or replace the inline block with a single call to
_skip_softmax_decision and ensure the threshold symbol is consistent
(SKIP_THRESHOLD_LOG2 vs skip_threshold_log2) and USE_SKIP_SCALE_FACTOR behavior
is respected; update only the code paths referenced by _skip_softmax_decision or
the inline variables (scores, row_max, tile_row_max, skip_tile,
SKIP_THRESHOLD_LOG2/skip_threshold_log2, Sparsity_total, Sparsity_skipped,
MEASURE_SPARSITY, USE_SKIP_SCALE_FACTOR) so the final file contains one
unambiguous implementation and consistent variable names.
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py`:
- Around line 107-115: The code logs the sparse Triton kernel message twice by
calling logger.warning(msg) and print(msg); remove the redundant print() call so
all output is routed through the configured logger (leave logger.warning(msg)
and the subsequent _LOGGED_CONFIGS.add(kw_key) intact), referencing the same
variables used in the message construction (variant_str, is_prefill, layer,
sparse_kw, kw_key) to locate the spot in vllm.py.
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py`:
- Around line 290-337: test_paged_decode only checks shape/NaN but not numeric
parity with a contiguous decode; call the same attention(...) without the paged
cache (pass k_flat/v_flat and omit k_cache/v_cache/block_table/page_size, using
the same q_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k and
softmax_scale) to produce a baseline output, then assert the paged output from
attention(...) is close to the baseline (use torch.allclose or torch.isclose
with reasonable rtol/atol) and also assert torch.isfinite on both outputs;
update the test_paged_decode function (references: test_paged_decode, attention,
_scatter_to_paged_cache, q_flat, k_flat, v_flat, k_cache, v_cache, block_table)
to include these comparisons.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: a1b4948d-004e-4822-9b51-e4f09740b68c
📒 Files selected for processing (11)
- examples/vllm_serve/sparse_attn_worker.py
- examples/vllm_serve/vllm_serve_sparse_attn.py
- modelopt/torch/kernels/__init__.py
- modelopt/torch/kernels/common/attention/triton_fa.py
- modelopt/torch/quantization/plugins/vllm.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
- modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
- pyproject.toml
- tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py
```python
with open(path) as f:
    calib_cfg = json.load(f)
```
Add error handling and specify encoding for file operations.
The file open operation lacks error handling and doesn't specify an encoding, which can cause issues on systems with non-UTF-8 defaults.
🛡️ Proposed fix

```diff
 def _load_sparse_config(path: str) -> dict:
     """Load offline calibration config JSON."""
-    with open(path) as f:
-        calib_cfg = json.load(f)
+    try:
+        with open(path, encoding="utf-8") as f:
+            calib_cfg = json.load(f)
+    except (FileNotFoundError, json.JSONDecodeError) as e:
+        raise ValueError(f"Failed to load sparse config from {path}: {e}") from e
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@examples/vllm_serve/sparse_attn_worker.py` around lines 107 - 108, Replace
the bare file open used to load calib_cfg with a robust read that specifies
encoding and handles errors: change the block using "with open(path) as f:
calib_cfg = json.load(f)" to "with open(path, 'r', encoding='utf-8') as f: ..."
and wrap it in a try/except that catches FileNotFoundError,
json.JSONDecodeError, and generic OSError to either log a clear error (including
the path and exception) or re-raise a descriptive exception; ensure the variable
name calib_cfg and the path variable remain the same so callers are unaffected.
```python
threshold = layer_cfg.get("skip_softmax_threshold")
if env_threshold is not None:
    threshold = env_threshold
if threshold:
    sparse_kw["skip_softmax_threshold"] = float(threshold)
```
Fix threshold validation to handle zero values.
The condition `if threshold:` evaluates to False when `threshold=0.0`, which would skip a valid zero threshold. Use `if threshold is not None:` instead.
🐛 Proposed fix

```diff
     threshold = layer_cfg.get("skip_softmax_threshold")
     if env_threshold is not None:
         threshold = env_threshold
-    if threshold:
+    if threshold is not None:
         sparse_kw["skip_softmax_threshold"] = float(threshold)
     return sparse_kw
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@examples/vllm_serve/sparse_attn_worker.py` around lines 148 - 152, The check
that sets sparse_kw["skip_softmax_threshold"] incorrectly treats 0.0 as falsy;
change the conditional that currently uses "if threshold:" to explicitly test
for non-None (e.g., "if threshold is not None:") so that a valid zero threshold
from layer_cfg or env_threshold is preserved; update the block that reads
threshold from layer_cfg and env_threshold (variables: threshold, layer_cfg,
env_threshold) and assigns to sparse_kw to use this non-None check.
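The falsy-zero pitfall is easy to demonstrate in isolation. This is a minimal sketch: `pick_threshold` is a hypothetical helper mirroring the fixed block, not code from the PR.

```python
def pick_threshold(layer_cfg, env_threshold):
    """Mirror of the corrected block: the env value overrides layer config,
    and a 0.0 threshold is preserved by the non-None check."""
    sparse_kw = {}
    threshold = layer_cfg.get("skip_softmax_threshold")
    if env_threshold is not None:
        threshold = env_threshold
    if threshold is not None:          # `if threshold:` would silently drop 0.0 here
        sparse_kw["skip_softmax_threshold"] = float(threshold)
    return sparse_kw


print(pick_threshold({"skip_softmax_threshold": 0.1}, env_threshold=0.0))
# → {'skip_softmax_threshold': 0.0}
```

With the original `if threshold:` guard, the zero override from the environment would vanish and the layer default 0.1 would never be replaced.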
```python
repo_root = str(Path(__file__).resolve().parent)
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}"
```
Fix PYTHONPATH concatenation to avoid empty path entry.
When PYTHONPATH is empty or unset, the current code creates a leading `:`, which adds an empty entry to the path and can cause Python to search the current directory. Use conditional logic to handle the empty case.
🐛 Proposed fix

```diff
 repo_root = str(Path(__file__).resolve().parent)
 if repo_root not in sys.path:
     sys.path.insert(0, repo_root)
-os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}"
+existing = os.environ.get("PYTHONPATH", "")
+os.environ["PYTHONPATH"] = f"{repo_root}:{existing}" if existing else repo_root
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@examples/vllm_serve/vllm_serve_sparse_attn.py` around lines 78 - 81, The
PYTHONPATH concatenation currently always prepends ":" which yields a leading
empty path entry when PYTHONPATH is empty; update the logic around
repo_root/sys.path.insert and os.environ["PYTHONPATH"] to read the existing
value (os.environ.get("PYTHONPATH")), and set os.environ["PYTHONPATH"] to either
"<existing> + os.pathsep + repo_root" if existing is non-empty or just
"repo_root" if not, using os.pathsep instead of a hardcoded ":" to be portable
and avoid introducing an empty path entry.
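The portable join suggested above can be sketched as follows; the `extended_pythonpath` helper is illustrative, not from the PR.

```python
import os


def extended_pythonpath(repo_root, existing):
    """Append repo_root to an existing PYTHONPATH without creating an empty
    entry, using the platform separator (':' on POSIX, ';' on Windows)."""
    if existing:
        return existing + os.pathsep + repo_root
    return repo_root


print(extended_pythonpath("/repo", ""))        # → /repo (no stray separator)
print(extended_pythonpath("/repo", "/other"))  # '/other:/repo' on POSIX
```

Using `os.pathsep` keeps the launcher working on Windows as well, where the hardcoded `:` would corrupt paths like `C:\repo`.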
```python
<<<<<<< HEAD
"""ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm)."""
=======
"""Shared Triton kernels for modelopt (attention, quantization, etc.)."""

import torch

from modelopt.torch.utils import import_plugin

IS_AVAILABLE = False
attention = None
attention_with_lse = None
register_triton_attention = None

if torch.cuda.is_available():
    with import_plugin(
        "triton",
        msg_if_missing=(
            "Your device is potentially capable of using the triton attention "
            "kernel. Try to install triton with `pip install triton`."
        ),
    ):
        from .triton_fa import attention as _attention
        from .triton_fa import attention_with_lse as _attention_with_lse

        attention = _attention
        attention_with_lse = _attention_with_lse
        IS_AVAILABLE = True
    with import_plugin("transformers"):
        from .hf_triton_attention import register_triton_attention as _register_triton_attention

        register_triton_attention = _register_triton_attention

__all__ = [
    "IS_AVAILABLE",
    "attention",
    "attention_with_lse",
    "register_triton_attention",
]
>>>>>>> 5ea4c609 (Added MLA support)
```
Resolve merge conflict in module initialization.
The module has unresolved merge-conflict markers at line 16. The HEAD version has a different docstring, while the incoming version introduces conditional CUDA-gated plugin loading. Resolve the conflict to finalize the module structure.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/kernels/__init__.py` around lines 16 - 55, Remove the leftover
merge conflict markers and keep the completed initialization logic from the
incoming change: use the descriptive module docstring, import torch and
import_plugin, define IS_AVAILABLE=False and
attention/attention_with_lse/register_triton_attention=None, then if
torch.cuda.is_available() use import_plugin("triton") to import .triton_fa
attention and attention_with_lse, set IS_AVAILABLE=True, and inside an
import_plugin("transformers") block import and assign register_triton_attention
from .hf_triton_attention; finally ensure __all__ contains "IS_AVAILABLE",
"attention", "attention_with_lse", and "register_triton_attention".
```python
<<<<<<< HEAD:modelopt/torch/kernels/common/attention/triton_fa.py
__all__ = ["LOG2E", "_apply_mask", "attention"]
=======
```
Resolve merge conflict in __all__ exports.
Determine the final export list. The incoming branch adds attention_with_lse (new API returning LSE). Verify whether LOG2E and _apply_mask from HEAD are needed for external consumers.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/kernels/common/attention/triton_fa.py` around lines 1323 -
1325, There is a merge conflict in the module exports: reconcile the __all__ to
export the correct public API by including the functions consumers need—ensure
LOG2E and _apply_mask are exported only if they are intended public symbols, and
add the new API attention_with_lse if the incoming branch introduced it; update
the __all__ to something like ["LOG2E", "_apply_mask", "attention",
"attention_with_lse"] (omit LOG2E/_apply_mask if they are internal), and keep
the symbol names exactly as defined in this file (LOG2E, _apply_mask, attention,
attention_with_lse) so imports work correctly.
```python
<<<<<<< HEAD
    "VSA_DEFAULT",
=======
    "SPARSE_SOFTMAX_SKIP_DEFAULT",
>>>>>>> 48eb4ea2 (First commit)
```
Resolve merge conflict in __all__ export list.
The __all__ list has unresolved merge-conflict markers. HEAD exports VSA_DEFAULT, while incoming exports SPARSE_SOFTMAX_SKIP_DEFAULT. Determine whether both should be exported or choose one.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 728 - 732,
The __all__ export list contains unresolved merge markers and is currently
choosing between VSA_DEFAULT and SPARSE_SOFTMAX_SKIP_DEFAULT; remove the
conflict markers and export the correct names — include both VSA_DEFAULT and
SPARSE_SOFTMAX_SKIP_DEFAULT in the __all__ list (no merge markers, no
duplicates) so both symbols are publicly exported from the module.
```toml
"uvloop>=0.22.1",
"vllm==0.20.1",
```
🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check if vllm/uvloop imports are optional or required at modelopt import time
rg -n -C3 'import vllm|from vllm|import uvloop|from uvloop' \
  --type=py \
  --glob '!examples/**' \
  --glob '!tests/**' \
  modelopt/
Repository: NVIDIA/Model-Optimizer
Length of output: 3840
🏁 Script executed:
# Check current dependency structure in pyproject.toml
cat pyproject.toml | head -100
Length of output: 4022
🏁 Script executed:
# Check if vllm/uvloop are imported unconditionally at package initialization
find modelopt -name "__init__.py" -type f | head -20
Length of output: 786
🏁 Script executed:
# Search for uvloop usage
rg -n 'uvloop' --type=py modelopt/
Length of output: 48
🏁 Script executed:
# Check how plugins are loaded - look for plugin registration/import logic
rg -n 'plugin|import.*vllm' modelopt/ -A2 -B2 | grep -E '(def|class|import|plugin)' | head -30
Length of output: 2799
🏁 Script executed:
# Check if vllm plugins are imported unconditionally in __init__.py files
cat modelopt/torch/quantization/plugins/__init__.py
Length of output: 2617
🏁 Script executed:
# Check sparsity plugins init
find modelopt -path "*sparsity*plugins*__init__.py" -type f -exec cat {} \;
Length of output: 2287
🏁 Script executed:
# More thorough search for uvloop - check all Python files
rg -n 'uvloop' . --type=py
Length of output: 362
🏁 Script executed:
# Check torch module __init__ to see what is imported at package init
head -150 modelopt/torch/__init__.py
Length of output: 2375
🏁 Script executed:
# Check import_plugin implementation to understand gating mechanism
rg -n 'def import_plugin' modelopt/ -A15
Length of output: 1307
🏁 Script executed:
# Check if sparsity/attention_sparsity/plugins/vllm.py is imported anywhere
rg -n 'attention_sparsity.*vllm|from.*attention_sparsity.*plugins' modelopt/ -B2 -A2
Length of output: 48
🏁 Script executed:
# Check if sparsity has its own vllm plugin import
find modelopt/torch/sparsity -name "__init__.py" -exec grep -l "vllm\|import_plugin" {} \;
Length of output: 189
🏁 Script executed:
# Check if sparsity/attention_sparsity/plugins/vllm.py is actually imported/used
cat modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
Length of output: 1343
🏁 Script executed:
# Double-check: Is there any import of sparsity.attention_sparsity.plugins.vllm anywhere?
rg -n 'sparsity.attention_sparsity.plugins.vllm|from.*sparsity.*attention.*vllm' .
Length of output: 183
🏁 Script executed:
# Check the sparsity module __init__.py to see what's imported
cat modelopt/torch/sparsity/__init__.py
Length of output: 1112
Move uvloop to example-specific requirements and consider making vllm optional.
Per coding guidelines, uvloop is used only in examples/vllm_serve/ scripts and should be listed in examples/vllm_serve/requirements.txt, not in core pyproject.toml dependencies. It is not imported anywhere in modelopt core code.
For vllm: While it is used in quantization plugins (gated via import_plugin("vllm")), it's only accessed by code explicitly importing the vllm plugin. Consider moving to [project.optional-dependencies] with a note that it's required for vLLM sparse-attention integration. Additionally, the exact version pin (==0.20.1) is fragile for a runtime dependency; consider using a version range (~=0.20.1 or >=0.20.1,<0.21).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@pyproject.toml` around lines 52 - 53, Remove "uvloop>=0.22.1" from the core
pyproject.toml dependencies and add it to the
examples/vllm_serve/requirements.txt since uvloop is only used by the example
scripts (refer to examples/vllm_serve/). For "vllm==0.20.1", move it out of core
dependencies into [project.optional-dependencies] (or an extras group) and
document that it is required only for the vLLM quantization plugin invoked via
import_plugin("vllm"); also relax the pin to a range such as ~=0.20.1 or
>=0.20.1,<0.21 to avoid a fragile exact-match dependency.
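A hedged sketch of the suggested layout (the extras-group name and the exact version bounds are illustrative, not taken from the repo):

```toml
# Sketch only: vllm becomes an extras group instead of a core dependency;
# uvloop is removed from core deps and listed in
# examples/vllm_serve/requirements.txt instead.
[project.optional-dependencies]
vllm = [
    # Needed only when the vLLM plugins are loaded via import_plugin("vllm").
    "vllm>=0.20.1,<0.21",
]
```

Users who need the vLLM integration would then install the extra explicitly (e.g. `pip install ".[vllm]"` from a checkout), while the default install stays vLLM-free.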
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
Force-pushed a90c9d6 to ea47408
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
705-715: ⚡ Quick win: Align skip-softmax field docs with the new combined preset.
SPARSE_SOFTMAX_SKIP_DEFAULT sets skip_softmax_threshold with method="triton_sparse_softmax", but the skip_softmax_threshold field description still says it is only used by triton_skip_softmax. Please update that description to include combined sparse+skip Triton usage.
♻️ Suggested doc update
- "Tiles contributing less than this fraction are skipped entirely. "
- "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
+ "Tiles contributing less than this fraction are skipped entirely. "
+ "Used by Triton skip-softmax modes (e.g., triton_skip_softmax and "
+ "triton_sparse_softmax when skip-softmax is enabled). Typical values: 1e-3 to 1e-1. "
  "Set to 0 to disable."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 705 - 715, Update the documentation string for the skip_softmax_threshold field to reflect that it is used not only by "triton_skip_softmax" but also by the combined Triton preset "triton_sparse_softmax" (as used in SPARSE_SOFTMAX_SKIP_DEFAULT); locate the skip_softmax_threshold description near the configuration definitions and modify the text to state it applies to both triton_skip_softmax and triton_sparse_softmax (combined sparse+skip Triton usage), ensuring examples or notes mention SPARSE_SOFTMAX_SKIP_DEFAULT where appropriate.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/kernels/common/attention/triton_fa.py`:
- Line 838: The paged KV mode (is_paged = k_cache is not None) is incompatible
with autograd because the forward reads from K_cache/V_cache while the backward
uses saved k/v in ctx, leading to incorrect gradients; update attention() to
either (a) add a runtime guard that raises an error when is_paged is true and
(k.requires_grad or v.requires_grad) to prevent silent gradient corruption, or
(b) clearly document the limitation in the attention() docstring (describe that
paged KV mode is inference-only and gradients are unsupported). Reference
is_paged, k_cache, v_cache, ctx, and function attention() when making the
change.
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 705-715: Update the documentation string for the
skip_softmax_threshold field to reflect that it is used not only by
"triton_skip_softmax" but also by the combined Triton preset
"triton_sparse_softmax" (as used in SPARSE_SOFTMAX_SKIP_DEFAULT); locate the
skip_softmax_threshold description near the configuration definitions and modify
the text to state it applies to both triton_skip_softmax and
triton_sparse_softmax (combined sparse+skip Triton usage), ensuring examples or
notes mention SPARSE_SOFTMAX_SKIP_DEFAULT where appropriate.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 8012e290-b808-40e5-8ad6-ec21855c1bde
📒 Files selected for processing (8)
examples/vllm_serve/sparse_attn_worker.py
modelopt/torch/kernels/common/attention/triton_fa.py
modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
modelopt/torch/quantization/plugins/vllm.py
modelopt/torch/sparsity/attention_sparsity/config.py
modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
pyproject.toml
tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
✅ Files skipped from review due to trivial changes (1)
- modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
🚧 Files skipped from review as they are similar to previous changes (4)
- pyproject.toml
- modelopt/torch/quantization/plugins/vllm.py
- examples/vllm_serve/sparse_attn_worker.py
- tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
kv_group_num = num_q_heads // num_kv_heads
batch = b_seq_len.shape[0]

is_paged = k_cache is not None
Paged KV mode is incompatible with autograd; consider adding a guard or documenting the limitation.
When is_paged=True, the forward kernel reads from K_cache/V_cache, but the backward kernels read from the contiguous k/v tensors saved in ctx. If these differ, gradients will be silently incorrect.
For the vLLM inference use case this PR targets, this is fine (no gradients needed). However, consider either:
- Adding a runtime guard when k.requires_grad or v.requires_grad with paged mode, or
- Documenting this limitation in the attention() docstring's paged KV parameters.
📝 Suggested documentation update
k_cache: Paged K cache [num_blocks, page_size, num_kv_heads, head_dim].
When provided, K/V are read from paged cache via block_table
- instead of from contiguous k/v tensors.
+ instead of from contiguous k/v tensors. **Note:** Autograd is not
+ supported in paged mode; backward gradients will be incorrect.
v_cache: Paged V cache [num_blocks, page_size, num_kv_heads, head_dim].
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/kernels/common/attention/triton_fa.py` at line 838, The paged
KV mode (is_paged = k_cache is not None) is incompatible with autograd because
the forward reads from K_cache/V_cache while the backward uses saved k/v in ctx,
leading to incorrect gradients; update attention() to either (a) add a runtime
guard that raises an error when is_paged is true and (k.requires_grad or
v.requires_grad) to prevent silent gradient corruption, or (b) clearly document
the limitation in the attention() docstring (describe that paged KV mode is
inference-only and gradients are unsupported). Reference is_paged, k_cache,
v_cache, ctx, and function attention() when making the change.
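Option (a), the runtime guard, can be sketched as below. This is a standalone illustration, not the kernel itself: FakeTensor is a stand-in for torch.Tensor (only its requires_grad attribute matters here), and in the real attention() the check would run on the actual k/v tensors before launching the forward kernel.

```python
class FakeTensor:
    """Minimal stand-in for torch.Tensor; only requires_grad matters here."""

    def __init__(self, requires_grad: bool = False):
        self.requires_grad = requires_grad


def check_paged_autograd(k, v, k_cache) -> None:
    """Option (a) from the review: refuse paged KV mode under autograd.

    In the real kernel, k/v are torch.Tensors and k_cache is the paged
    K cache (or None for contiguous mode).
    """
    is_paged = k_cache is not None
    if is_paged and (k.requires_grad or v.requires_grad):
        raise RuntimeError(
            "Paged KV mode is inference-only: backward reads the saved "
            "contiguous k/v, not k_cache/v_cache, so gradients would be wrong."
        )


# Inference path (no grads) passes silently.
check_paged_autograd(FakeTensor(), FakeTensor(), k_cache=object())

# Training path with a paged cache is rejected.
try:
    check_paged_autograd(FakeTensor(requires_grad=True), FakeTensor(), k_cache=object())
    raised = False
except RuntimeError:
    raised = True
print(raised)  # prints True
```

Failing fast here costs one attribute check per call and turns a silent-wrong-gradient bug into an immediate, explainable error.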
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1499 +/- ##
===========================================
- Coverage 77.44% 62.55% -14.89%
===========================================
Files 473 474 +1
Lines 51418 52331 +913
===========================================
- Hits 39819 32736 -7083
- Misses 11599 19595 +7996
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
What does this PR do?
Builds on vLLM integration PR :
Type of change: new feature
Usage
Summary by CodeRabbit
New Features
Dependencies
uvloop and vllm==0.20.1 dependencies.
Tests