
[OMNIML-4730] Support quantized nn.Embedding #1495

Open
ajrasane wants to merge 3 commits into main from ajrasane/quant_embedding

Conversation

@ajrasane
Contributor

@ajrasane ajrasane commented May 14, 2026

What does this PR do?

Type of change: new feature

Register nn.Embedding in QuantModuleRegistry so the embedding table and lookup activations participate in quantization end-to-end:

  • New modelopt/torch/quantization/nn/modules/quant_embedding.py exposes weight_quantizer (embedding table), output_quantizer (lookup activations, off by default), and an input_quantizer placeholder (a minimal sketch of the placeholder follows this list). Embedding inputs are integer indices that cannot be fake-quantized, so direct enable() / enable_quant() / enable_calib() calls on input_quantizer raise, and forward() raises if _disabled is flipped via any back door. Wildcard configs (*input_quantizer) are accepted silently so the stock deny-all → enable-wildcards → opt-out pattern in NVFP4_DEFAULT_CFG and friends still works.
  • default_disabled_quantizers.yaml installs parent_class: nn.Embedding, enable: false so embedding quantization is opt-in and existing model behavior is unchanged.
  • is_quantized_linear in core_utils.py early-returns False for nn.Embedding so AWQ / SmoothQuant / SVDQuant don't treat it as a GEMM op.
  • _process_quantized_modules in unified_export_hf.py routes quantized nn.Embedding modules through _export_quantized_weight, so the exported checkpoint contains the packed NVFP4 / FP8 / INT bytes plus weight_scale* buffers, exactly like Linear layers.
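
A minimal sketch of the locked input quantizer described in the first bullet above, assuming the class derives from modelopt's TensorQuantizer; the import path and exception type here are illustrative, not taken from the PR:

from modelopt.torch.quantization.nn import TensorQuantizer  # assumed re-export path


class _UnsettableInputQuantizer(TensorQuantizer):
    """Embedding inputs are integer indices, so this quantizer can never be enabled."""

    def _refuse(self, *args, **kwargs):
        # The PR only says these calls raise; the exception type is a guess.
        raise RuntimeError("nn.Embedding inputs are integer indices and cannot be fake-quantized.")

    enable = enable_quant = enable_calib = _refuse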

Usage

import copy
import torch.nn as nn
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

# Opt embeddings into the stock NVFP4 config — the YAML default is opt-out.
cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
cfg["quant_cfg"].append(
    {
        "parent_class": "nn.Embedding",
        "quantizer_name": "*weight_quantizer",
        "cfg": {"num_bits": (2, 1), "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}},
    }
)

model = mtq.quantize(model, cfg, forward_loop)
export_hf_checkpoint(model, export_dir="./out")
# out/model.safetensors contains: embedding.weight (uint8, NVFP4-packed),
# embedding.weight_scale (FP8 E4M3 per-block), embedding.weight_scale_2 (FP32).
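
The exported buffers can be sanity-checked with the safetensors library; a minimal sketch, assuming the key names in the comment above (they depend on the model's module naming):

from safetensors import safe_open

with safe_open("./out/model.safetensors", framework="pt", device="cpu") as f:
    for key in ("embedding.weight", "embedding.weight_scale", "embedding.weight_scale_2"):
        t = f.get_tensor(key)
        print(key, t.dtype, tuple(t.shape))  # expect uint8, float8_e4m3fn, float32 respectively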

Testing

  • New unit tests tests/unit/torch/quantization/test_quant_embedding.py cover: default quantizer state, no-quant identity, per-tensor and per-row weight fake quant against the manual tensor_quant.fake_tensor_quant reference (a standalone per-row reference is sketched after this list), output quantizer activation, locked-mutator raises (parametrized over enable / enable_quant / enable_calib), forward-time guard for back-door _disabled = False, and the wildcard-then-opt-out pattern. All 9 cases pass.
  • Verified end-to-end on an embedding-only model: mtq.quantize with NVFP4_DEFAULT_CFG + the embedding opt-in produces embedding.weight (uint8), embedding.weight_scale (float8_e4m3fn), embedding.weight_scale_2 (float32) in the exported safetensors, with "quant_algo": "NVFP4" in hf_quant_config.json.
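
For orientation, a per-row fake-quant reference of the kind the tests compare against can be written directly in torch. This standalone INT8 sketch only illustrates the idea and is not the modelopt tensor_quant helper itself:

import torch

weight = torch.randn(32, 8)                    # embedding table: num_embeddings x embedding_dim
amax = weight.abs().amax(dim=1, keepdim=True)  # per-row absolute maximum
scale = amax / 127.0                           # 8-bit signed range
ref = torch.clamp((weight / scale).round(), -127, 127) * scale  # fake-quantized reference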

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ — embedding quantizers are opt-in via parent_class: nn.Embedding, enable: false in default_disabled_quantizers.yaml, so existing model behavior is unchanged.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅
  • Did you get Claude approval on this PR?: ❌ — will run /claude review after the PR is up.

Additional Information

Summary by CodeRabbit

  • New Features

    • Opt-in quantization for embedding layers: configurable weight quantization and optional output quantization; input quantization is permanently disabled by default.
    • Quantized embedding weights are packed into exported checkpoints alongside other quantized layers.
  • Bug Fixes

    • Preserve tied embedding weights during export (packing skipped with a warning) to avoid breaking weight ties.
  • Tests

    • Added tests for embedding quantization behavior, export packing, calibration, and tied-weight scenarios.
  • Documentation

    • Changelog updated with 0.45 embedding quantization notes.

Review Change Stack

Register nn.Embedding in QuantModuleRegistry so the embedding table and
the lookup activations participate in quantization. The literal input is
integer indices, so input_quantizer is a non-configurable placeholder
that raises on direct enable*() calls and at forward-time if its
_disabled flag is flipped — wildcard configs (e.g. NVFP4_DEFAULT_CFG's
*input_quantizer) are accepted silently so the stock deny-all → enable
wildcards → opt-out pattern continues to work, and the opt-out is
installed by default (parent_class: nn.Embedding in
default_disabled_quantizers.yaml). export_hf_checkpoint packs quantized
embedding weights through the same path as Linear layers.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
@ajrasane ajrasane requested review from a team as code owners May 14, 2026 17:54
@ajrasane ajrasane requested a review from cjluo-nv May 14, 2026 17:54
@coderabbitai
Contributor

coderabbitai Bot commented May 14, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 5c1d1f97-32cd-4756-a7ee-b9480aebee60

📥 Commits

Reviewing files that changed from the base of the PR and between 7ed34d5 and 4c4db31.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/nn/modules/quant_embedding.py
  • tests/unit/torch/quantization/test_quant_embedding.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/quantization/nn/modules/quant_embedding.py
  • tests/unit/torch/quantization/test_quant_embedding.py

📝 Walkthrough

Adds QuantEmbedding (quantized nn.Embedding) with gated weight quantization, a permanently disabled input-quantizer, optional output quantization, export packing support (with tied-weight skip), calibration exclusion, default-disabled config entry, unit tests, and a changelog entry.

Changes

Quantized Embedding Support

  • QuantEmbedding Core Implementation (modelopt/torch/quantization/nn/modules/quant_embedding.py): Introduces _QuantEmbedding with a weight quantizer (gated by quantize_weight() / export mode), an _UnsettableInputQuantizer that raises on enable attempts, an optional output_quantizer, _get_quantized_weight dynamic backing for weight, and forward() enforcing a disabled input quantizer and applying output quantization.
  • Integration, Export, and Calibration (modelopt/torch/quantization/nn/__init__.py, modelopt/torch/export/unified_export_hf.py, modelopt/torch/quantization/utils/core_utils.py, modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml, CHANGELOG.rst): Re-exports QuantEmbedding, excludes nn.Embedding from linear-module detection (is_quantized_linear), adds an export branch to pack quantized embedding weights under fsdp2_aware_weight_update (skips packing when .weight is tied), inserts a default-disabled quantizer YAML entry for nn.Embedding, and documents the feature in the changelog.
  • Tests and Export Checks (tests/unit/torch/quantization/test_quant_embedding.py, CHANGELOG.rst): Adds unit tests validating default quantizer states, weight fake-quant (per-tensor and per-row), output-quantizer application and calibration behavior, input-quantizer enablement protections, wildcard/config behavior, and export packing vs. tied-weight skipping; updates the changelog entry.

sequenceDiagram
  participant Client
  participant QuantEmbedding
  participant WeightQuantizer
  participant OutputQuantizer
  Client->>QuantEmbedding: forward(input_indices)
  QuantEmbedding->>QuantEmbedding: ensure input_quantizer disabled
  QuantEmbedding->>WeightQuantizer: get quantized weight (if enabled or export)
  alt quantized weight returned
    QuantEmbedding->>QuantEmbedding: lookup with quantized weight
  else raw weight used
    QuantEmbedding->>QuantEmbedding: lookup with raw weight
  end
  QuantEmbedding->>OutputQuantizer: apply output quantizer if enabled
  QuantEmbedding-->>Client: return embeddings
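
In Python terms, the flow in the diagram corresponds roughly to the sketch below. This is a simplification: per the summary above, the real module backs weight with a dynamic _get_quantized_weight attribute rather than quantizing inline in forward.

import torch.nn.functional as F

def forward(self, input):  # simplified stand-in for _QuantEmbedding.forward
    weight = self.weight
    if self.weight_quantizer.is_enabled:  # fake-quantize the embedding table when opted in
        weight = self.weight_quantizer(weight)
    output = F.embedding(input, weight, self.padding_idx, self.max_norm,
                         self.norm_type, self.scale_grad_by_freq, self.sparse)
    return self.output_quantizer(output)  # pass-through when the output quantizer is disabled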

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title directly and accurately describes the main change: adding support for quantized nn.Embedding modules throughout the codebase (registration, forward/export implementation, tests, configuration).
  • Docstring Coverage ✅ Passed: Docstring coverage is 95.45%, which is sufficient. The required threshold is 80.00%.
  • Linked Issues Check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns ✅ Passed: All changes comply with SECURITY.md. No unsafe deserialization, hardcoded secrets, RCE vectors, nosec bypasses, or non-permissive dependencies detected.




@github-actions
Contributor

github-actions Bot commented May 14, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1495/

Built to branch gh-pages at 2026-05-14 19:30 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@ajrasane ajrasane changed the title feat(quant): support quantized nn.Embedding [OMNIML-4730] Support quantized nn.Embedding May 14, 2026
Collaborator

@cjluo-nv cjluo-nv left a comment


Bot review — DM the bot to share feedback.

Quantized nn.Embedding support cleanly mirrors the _QuantLinear/QuantLinearConvBase pattern (dynamic-attribute weight, quantize_weight context, _register_temp_attribute), so design-wise it slots in well — no second composition system. The wildcard-tolerance via _UnsettableInputQuantizer is unusual but justified: stock recipes apply *input_quantizer enables, and the YAML parent_class: nn.Embedding, *, enable: false rule is appended last in every preset that uses _default_disabled_quantizer_cfg, so the disabled state is restored before forward. Unit tests cover the lock semantics and weight quant against fake_tensor_quant reference.

Three things worth a maintainer look before approving:

  1. output_quantizer is silently bypassed under torch.export. _QuantEmbedding.forward does if is_torch_export_mode(): return super().forward(...) — that path never calls self.output_quantizer(output). QuantLinearConvBase/QuantInputBase both keep the output_quantizer in the export path. If a user opts into output_quantizer and then torch.exports, they'll lose it without warning. Probably harmless today (output_quantizer is off by default) but it's an inconsistency.

  2. Tied embeddings (tied_word_embeddings=True) likely break on export. _export_quantized_weight does setattr(sub_module, weight_name, nn.Parameter(quantized_weight, ...)), replacing embedding.weight with a new Parameter holding packed uint8 bytes. If lm_head.weight was tied to the same Parameter, the tie is severed and lm_head keeps a stale float weight; postprocess_state_dict's tied-weight dedup will then drop one of the keys from the safetensors output. The PR description's example uses an embedding-only model, which sidesteps this; in real LLMs (Llama/Qwen with tied embeddings) this needs at least a guard or explicit warning (a quick tie check is sketched after this comment).

  3. No export-path test. All new tests are pure forward tests; the new _process_quantized_modules branch routing nn.Embedding through _export_quantized_weight has no coverage. Given (2), an export round-trip test on a tiny tied-embedding model would catch the issue. The PR description says it was verified manually on an embedding-only model — that's exactly the case that doesn't exercise the tying path.

Smaller/optional: the _UnsettableInputQuantizer.enable* overrides catch user-facing direct calls, but set_from_attribute_config({"enable": True}) writes _disabled directly via setattr, so the only real defense is the runtime check in forward. The current docstring already explains this; just confirm the runtime guard is the load-bearing one and the method overrides are belt-and-suspenders.
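
For point 2, a quick way to confirm the tie before export, assuming a Hugging Face transformers model (a generic check, not code from this PR):

import warnings

emb = model.get_input_embeddings()    # e.g. model.model.embed_tokens in Llama-style models
head = model.get_output_embeddings()  # e.g. lm_head; None for models without an output head
if head is not None and head.weight is emb.weight:
    warnings.warn("embedding.weight is tied to the output head; packing it in place would sever the tie")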

ajrasane added 2 commits May 14, 2026 19:17
- Apply output_quantizer in the torch.export branch of _QuantEmbedding.forward
  so users who opt into output activation quantization don't silently lose it
  during export. Matches QuantInputBase.forward's behavior.
- Detect Python-level weight tying (e.g. tied_word_embeddings → lm_head) in
  _process_quantized_modules and skip packing the embedding when the .weight
  Parameter is shared, with a UserWarning. Packing would otherwise reassign
  the embedding's .weight to a new uint8 Parameter, severing the tie and
  leaving the tied module pointing at a stale float Parameter.
- Add export-path tests covering the normal pack flow (weight → uint8 +
  weight_scale + weight_scale_2 buffers) and the tied-embedding skip path
  (weight unchanged, warning raised, tie preserved).

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
The previous design raised in _QuantEmbedding.forward whenever
input_quantizer.is_enabled, on the theory that any non-disable config was
an explicit user mistake. That assumption was wrong for wildcard configs:
the default QuantizeConfig is just [{"quantizer_name": "*", "cfg":
{"num_bits": 8, ...}}] (no embedding opt-out), so the wildcard enables
embed_tokens.input_quantizer for tiny Llama-style tests and the forward
guard fires — breaking test_peft_save_load and test_transformers_save_load.

Switch _UnsettableInputQuantizer.set_from_attribute_config to absorb the
incoming config like a normal quantizer, then force _disabled = True at
the end. The "throw on explicit set" semantics are preserved via the
.enable / .enable_quant / .enable_calib overrides, which catch the direct
mistakes users would actually make. The forward-time guard (and the
corresponding test) are removed since the invariant is now maintained at
the configure step.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
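
A hypothetical shape of the revised override described above (the import path is assumed, and the TensorQuantizer base may handle attribute configs differently than shown):

from modelopt.torch.quantization.nn import TensorQuantizer  # assumed re-export path


class _UnsettableInputQuantizer(TensorQuantizer):
    def set_from_attribute_config(self, attribute_cfg):
        super().set_from_attribute_config(attribute_cfg)  # absorb wildcard config like a normal quantizer
        self._disabled = True                             # then re-assert the disabled state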
