Skip to content

perf: optimize guided decoding with xgrammar upgrade, batched API, and async D2H overlap#4605

Open
windreamer wants to merge 6 commits into
InternLM:mainfrom
windreamer:feat/guided-decoding-optimization
Open

perf: optimize guided decoding with xgrammar upgrade, batched API, and async D2H overlap#4605
windreamer wants to merge 6 commits into
InternLM:mainfrom
windreamer:feat/guided-decoding-optimization

Conversation

@windreamer

@windreamer windreamer commented May 21, 2026

Copy link
Copy Markdown
Collaborator

Motivation

Guided decoding in TurboMind has several performance bottlenecks: per-matcher loops for FillNextTokenBitmask/AcceptToken, synchronous D2H copy blocking the main stream, and a GrammarCompiler instantiated per request. This PR addresses all three to reduce guided decoding overhead.

Modification

  1. Upgrade xgrammar from v0.1.27 to v0.2.1 — enables the batched API (BatchFillNextTokenBitmask, BatchAcceptToken).
  2. Batched grammar operations (C++) — replace per-matcher FillNextTokenBitmask/AcceptToken loops with BatchGrammarMatcher::BatchFillNextTokenBitmask and BatchAcceptToken, reducing per-token overhead proportional to batch size.
  3. Overlap D2H copy with GPU work — issue the output_ids D2H copy on an independent CUDA stream so it overlaps with AppendTokenIds + stop_criteria on the main stream, hiding the copy latency.
  4. Lazy-shared GrammarCompiler (Python) — create GrammarCompiler once per TurboMind instance (lazily) instead of per request, avoiding repeated tokenizer introspection.
  5. Fix PyTorch engine accept_token — convert tensor to Python int (.item()) before passing to xgrammar, fixing a type mismatch with the new API.
  6. CMake: change target_link_libraries from PRIVATE to PUBLIC for the guided_decoding static library so that dependent targets correctly propagate xgrammar headers.

BC-breaking (Optional)

None. The API surface is unchanged; only internal guided decoding paths are affected.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@windreamer windreamer force-pushed the feat/guided-decoding-optimization branch from 9a942b6 to 3fdff4a Compare May 21, 2026 04:23
@windreamer windreamer changed the title perf: speed up guided decoding with xgrammar new version and batched update perf: optimize guided decoding with xgrammar upgrade, batched API, and async D2H overlap May 21, 2026
@windreamer windreamer marked this pull request as ready for review May 21, 2026 08:15
Copilot AI review requested due to automatic review settings May 21, 2026 08:15

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR optimizes guided decoding performance in TurboMind by upgrading xgrammar and refactoring guided-decoding paths to use batched matcher APIs plus CUDA stream/event orchestration to overlap host-device transfers with GPU work. It also reduces Python-side overhead by reusing a lazily constructed GrammarCompiler and fixes a PyTorch guided-decoding type mismatch introduced by the xgrammar upgrade.

Changes:

  • Upgrade xgrammar to v0.2.1 and switch C++ guided decoding to batched matcher APIs (BatchFillNextTokenBitmask / BatchAcceptToken).
  • Overlap output_ids D2H copies with GPU kernels via a secondary CUDA stream and split guided decoding update into ScheduleUpdate + FinishUpdate.
  • Cache GrammarCompiler per TurboMind instance (lazy init) and fix PyTorch accept_token to pass a Python int via .item().

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
src/turbomind/generation/guided_decoding.h Adds batched matcher + CUDA stream/event members; splits Update into two phases.
src/turbomind/generation/guided_decoding.cc Implements batched xgrammar calls and async D2H overlap using events/streams; adds needs_apply gating.
src/turbomind/generation/generation.cc Integrates ScheduleUpdate/FinishUpdate around AppendTokenIds and stop_criteria to enable overlap.
src/turbomind/generation/CMakeLists.txt Exposes xgrammar/core linkage publicly for guided_decoding consumers.
lmdeploy/turbomind/turbomind.py Introduces lazy-shared GrammarCompiler and removes per-request instantiation.
lmdeploy/pytorch/engine/logits_process.py Passes token id as Python int (.item()) to guided decoding manager.
CMakeLists.txt Bumps FetchContent xgrammar tag to v0.2.1.
Comments suppressed due to low confidence (1)

src/turbomind/generation/guided_decoding.cc:135

  • Similarly, FinishUpdate() allocates active_matchers and active_token_ids every step without reserving. Reserving (or persisting these vectors in the phase Data) would reduce per-token allocation overhead, especially for large batch sizes.
            // Collect active matchers and their token IDs for batch AcceptToken
            std::vector<xgrammar::GrammarMatcher> active_matchers;
            std::vector<int32_t>                  active_token_ids;


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/turbomind/generation/guided_decoding.cc Outdated
Comment thread src/turbomind/generation/guided_decoding.cc Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (2)

src/turbomind/generation/guided_decoding.cc:129

  • FinishUpdate() calls d2h_done_.Sync() on all TP ranks even though only rank 0 performs BatchAcceptToken. On non-zero ranks this host-side wait is pure overhead (and can also introduce unnecessary CPU/GPU synchronization points). Consider moving the sync + matcher update under the tp_group_->rank() == 0 branch, or early-returning for non-zero ranks.
    if (auto& d = *data_.at(phase); d.active) {
        // Wait only for the D2H copy to complete — the main stream's
        // AppendTokenIds + stop_criteria may still be executing on GPU.
        d2h_done_.Sync();

        if (tp_group_->rank() == 0) {

lmdeploy/pytorch/engine/logits_process.py:484

  • The guided-decoding accept_token call site changed to pass a Python int, but there is no unit test in tests/pytorch/engine/test_logits_process.py covering the guided_decoding_manager integration path (e.g., that accept_token is invoked with the expected token values/types). Adding a small test with a stub GuidedDecodingManager would help prevent regressions when sampling runs on CUDA tensors.
        if self.guided_decoding_manager and self.guided_processors:
            for i, processor in self.guided_processors.items():
                self.guided_decoding_manager.accept_token(processor, result[i].item())

Comment thread src/turbomind/generation/guided_decoding.cc
Comment thread lmdeploy/pytorch/engine/logits_process.py Outdated
@windreamer windreamer force-pushed the feat/guided-decoding-optimization branch from b5a3678 to 49c617f Compare May 21, 2026 09:15
@windreamer windreamer requested a review from Copilot May 21, 2026 09:29

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (1)

src/turbomind/generation/guided_decoding.cc:139

  • FinishUpdate() iterates over all d.matchers and calls batch AcceptToken using output_ids_buf_[i] for every non-terminated matcher. However, only the first generation_size sequences receive a newly sampled token each step; for the remaining slots output_ids_buf_ may contain stale values (since sampling runs with batch_size = logits.shape(0)). This can advance grammar state incorrectly for sequences that were not generating this step. Limit the loop to generation_size (saved from ScheduleUpdate()), or gate on the per-request generating mask.
    if (auto& d = *data_.at(phase); d.active && tp_group_->rank() == 0) {
        // Wait only for the D2H copy to complete — the main stream's
        // AppendTokenIds + stop_criteria may still be executing on GPU.
        d2h_done_.Sync();

        // Collect active matchers and their token IDs for batch AcceptToken
        std::vector<xgrammar::GrammarMatcher> active_matchers;
        std::vector<int32_t>                  active_token_ids;
        active_matchers.reserve(d.matchers.size());
        active_token_ids.reserve(d.matchers.size());

        for (size_t i = 0; i < d.matchers.size(); ++i) {
            if (const auto& m = d.matchers[i]; m && !m->IsTerminated()) {
                active_matchers.emplace_back(*m);
                active_token_ids.emplace_back(output_ids_buf_[i]);
            }

Comment thread src/turbomind/generation/guided_decoding.cc
Comment thread src/turbomind/generation/guided_decoding.cc Outdated
windreamer added a commit to windreamer/lmdeploy that referenced this pull request May 21, 2026
FillMask, ScheduleUpdate, and FinishUpdate previously iterated over
d.matchers.size() entries, but only the first generation_size
(= logits.shape(0)) slots are actively generating. Entries beyond
that index contain stale output_ids and unused bitmasks.

- FillMask: limit matcher iteration and reserve to gs = logits.shape(0)
- ScheduleUpdate: copy only gs output_ids entries for D2H transfer
- FinishUpdate: add TensorMap& env param, iterate only over gs slots

Fixes review comments on PR InternLM#4605 (3280137130, 3280137198).
@windreamer windreamer requested a review from Copilot May 21, 2026 10:14

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

Comment thread src/turbomind/generation/guided_decoding.cc Outdated
Comment thread src/turbomind/generation/CMakeLists.txt Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated no new comments.

windreamer added a commit to windreamer/lmdeploy that referenced this pull request May 22, 2026
FillMask, ScheduleUpdate, and FinishUpdate previously iterated over
d.matchers.size() entries, but only the first generation_size
(= logits.shape(0)) slots are actively generating. Entries beyond
that index contain stale output_ids and unused bitmasks.

- FillMask: limit matcher iteration and reserve to gs = logits.shape(0)
- ScheduleUpdate: copy only gs output_ids entries for D2H transfer
- FinishUpdate: add TensorMap& env param, iterate only over gs slots

Fixes review comments on PR InternLM#4605 (3280137130, 3280137198).
@windreamer windreamer force-pushed the feat/guided-decoding-optimization branch from ea5c50b to e24c147 Compare May 22, 2026 01:39
windreamer added a commit to windreamer/lmdeploy that referenced this pull request Jun 5, 2026
FillMask, ScheduleUpdate, and FinishUpdate previously iterated over
d.matchers.size() entries, but only the first generation_size
(= logits.shape(0)) slots are actively generating. Entries beyond
that index contain stale output_ids and unused bitmasks.

- FillMask: limit matcher iteration and reserve to gs = logits.shape(0)
- ScheduleUpdate: copy only gs output_ids entries for D2H transfer
- FinishUpdate: add TensorMap& env param, iterate only over gs slots

Fixes review comments on PR InternLM#4605 (3280137130, 3280137198).
@windreamer windreamer force-pushed the feat/guided-decoding-optimization branch from e24c147 to 33563fe Compare June 5, 2026 02:39
windreamer added a commit to windreamer/lmdeploy that referenced this pull request Jun 5, 2026
FillMask, ScheduleUpdate, and FinishUpdate previously iterated over
d.matchers.size() entries, but only the first generation_size
(= logits.shape(0)) slots are actively generating. Entries beyond
that index contain stale output_ids and unused bitmasks.

- FillMask: limit matcher iteration and reserve to gs = logits.shape(0)
- ScheduleUpdate: copy only gs output_ids entries for D2H transfer
- FinishUpdate: add TensorMap& env param, iterate only over gs slots

Fixes review comments on PR InternLM#4605 (3280137130, 3280137198).
@windreamer windreamer force-pushed the feat/guided-decoding-optimization branch from 33563fe to ba4efab Compare June 5, 2026 02:45
@windreamer windreamer marked this pull request as draft June 9, 2026 01:12
… CUDA stream

Split GuidedDecoding::Update() into ScheduleUpdate() + FinishUpdate()
to enable D2H copy of output_ids on a secondary CUDA stream, overlapping
with AppendTokenIds and stop_criteria GPU kernels on the main stream.

- ScheduleUpdate(): records sampling_done event on main stream, launches
  async D2H copy on d2h_stream_ (waits for sampling_done first)
- FinishUpdate(): syncs on d2h_done event, then runs BatchAcceptToken on CPU
- Adds d2h_stream_, sampling_done_, d2h_done_ members (created once in ctor)
- Eliminates the blocking cudaStreamSynchronize that previously stalled the
  CPU between sampling and AcceptToken

This is optimization 5 (Plan I): independent CUDA stream for D2H copy
parallelism, removing a sync point in the decode step hot path.
FillMask, ScheduleUpdate, and FinishUpdate previously iterated over
d.matchers.size() entries, but only the first generation_size
(= logits.shape(0)) slots are actively generating. Entries beyond
that index contain stale output_ids and unused bitmasks.

- FillMask: limit matcher iteration and reserve to gs = logits.shape(0)
- ScheduleUpdate: copy only gs output_ids entries for D2H transfer
- FinishUpdate: add TensorMap& env param, iterate only over gs slots

Fixes review comments on PR InternLM#4605 (3280137130, 3280137198).
@windreamer windreamer force-pushed the feat/guided-decoding-optimization branch from ba4efab to 84a90a2 Compare June 9, 2026 02:41
@windreamer windreamer marked this pull request as ready for review June 9, 2026 04:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants