[OMNIML-4730] Support quantized nn.Embedding #1495
Conversation
Register nn.Embedding in QuantModuleRegistry so the embedding table and the lookup activations participate in quantization. The literal input is integer indices, so input_quantizer is a non-configurable placeholder that raises on direct enable*() calls and at forward time if its _disabled flag is flipped. Wildcard configs (e.g. NVFP4_DEFAULT_CFG's *input_quantizer) are accepted silently so the stock deny-all → enable wildcards → opt-out pattern continues to work, and the opt-out is installed by default (parent_class: nn.Embedding in default_disabled_quantizers.yaml). export_hf_checkpoint packs quantized embedding weights through the same path as Linear layers.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
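For context, the deny-all → enable-wildcards → opt-out pattern looks roughly like the config sketch below. The keys and values are simplified assumptions for illustration only, not the actual NVFP4_DEFAULT_CFG contents.

```python
# Illustrative only: simplified quant-config dict showing the ordering that the
# commit message above refers to. Globs and values are assumptions.
EXAMPLE_CFG = {
    "quant_cfg": {
        # Wildcards enable weight and input quantizers everywhere ...
        "*weight_quantizer": {"num_bits": 8, "axis": 0},
        "*input_quantizer": {"num_bits": 8, "axis": None},
        # ... and targeted opt-outs (like the nn.Embedding rule appended from
        # default_disabled_quantizers.yaml) switch specific quantizers back off.
        "*embed_tokens*input_quantizer": {"enable": False},
        # Everything not matched above stays disabled.
        "default": {"enable": False},
    },
    "algorithm": "max",
}
```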
No actionable comments were generated in the recent review. 🎉
📝 Walkthrough

Adds QuantEmbedding (quantized nn.Embedding) with gated weight quantization, a permanently disabled input-quantizer, optional output quantization, export packing support (with tied-weight skip), calibration exclusion, a default-disabled config entry, unit tests, and a changelog entry.

Changes: Quantized Embedding Support
sequenceDiagram
participant Client
participant QuantEmbedding
participant WeightQuantizer
participant OutputQuantizer
Client->>QuantEmbedding: forward(input_indices)
QuantEmbedding->>QuantEmbedding: ensure input_quantizer disabled
QuantEmbedding->>WeightQuantizer: get quantized weight (if enabled or export)
alt quantized weight returned
QuantEmbedding->>QuantEmbedding: lookup with quantized weight
else raw weight used
QuantEmbedding->>QuantEmbedding: lookup with raw weight
end
QuantEmbedding->>OutputQuantizer: apply output quantizer if enabled
QuantEmbedding-->>Client: return embeddings
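For a code-level view of the same flow, here is a minimal, self-contained sketch. FakeQuantizer is a toy stand-in for the real TensorQuantizer, and QuantEmbeddingSketch only mirrors the weight/output quantizer behavior shown in the diagram; it is not the modelopt implementation.

```python
import torch
import torch.nn as nn


class FakeQuantizer(nn.Module):
    """Toy stand-in for a fake quantizer: identity when disabled."""

    def __init__(self, enabled: bool = True):
        super().__init__()
        self.enabled = enabled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return x
        # Per-tensor int8 quantize-dequantize, purely for illustration.
        scale = x.abs().amax().clamp(min=1e-8) / 127.0
        return (x / scale).round().clamp(-128, 127) * scale


class QuantEmbeddingSketch(nn.Embedding):
    """Sketch of the forward flow above; integer indices are never quantized."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_quantizer = FakeQuantizer(enabled=True)
        self.output_quantizer = FakeQuantizer(enabled=False)  # off by default

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Fake-quantize the embedding table, look up with the quantized weight,
        # then optionally fake-quantize the lookup output.
        qweight = self.weight_quantizer(self.weight)
        out = nn.functional.embedding(input_ids, qweight, self.padding_idx)
        return self.output_quantizer(out)


emb = QuantEmbeddingSketch(16, 8)
print(emb(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 8])
```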
cjluo-nv left a comment
Bot review — DM the bot to share feedback.
Quantized nn.Embedding support cleanly mirrors the _QuantLinear/QuantLinearConvBase pattern (dynamic-attribute weight, quantize_weight context, _register_temp_attribute), so design-wise it slots in well — no second composition system. The wildcard-tolerance via _UnsettableInputQuantizer is unusual but justified: stock recipes apply *input_quantizer enables, and the YAML parent_class: nn.Embedding, *, enable: false rule is appended last in every preset that uses _default_disabled_quantizer_cfg, so the disabled state is restored before forward. Unit tests cover the lock semantics and weight quant against fake_tensor_quant reference.
Three things worth a maintainer look before approving:
- output_quantizer is silently bypassed under torch.export. _QuantEmbedding.forward does if is_torch_export_mode(): return super().forward(...) — that path never calls self.output_quantizer(output). QuantLinearConvBase / QuantInputBase both keep the output_quantizer in the export path. If a user opts into output_quantizer and then torch.exports, they'll lose it without warning. Probably harmless today (output_quantizer is off by default) but it's an inconsistency.
- Tied embeddings (tied_word_embeddings=True) likely break on export. _export_quantized_weight does setattr(sub_module, weight_name, nn.Parameter(quantized_weight, ...)), replacing embedding.weight with a new Parameter holding packed uint8 bytes. If lm_head.weight was tied to the same Parameter, the tie is severed and lm_head keeps a stale float weight; postprocess_state_dict's tied-weight dedup will then drop one of the keys from the safetensors output. The PR description's example uses an embedding-only model, which sidesteps this — but in real LLMs (Llama/Qwen with tied embeddings) this needs at least a guard or explicit warning.
- No export-path test. All new tests are pure forward tests; the new _process_quantized_modules branch routing nn.Embedding through _export_quantized_weight has no coverage. Given (2), an export round-trip test on a tiny tied-embedding model would catch the issue. The PR description says it was verified manually on an embedding-only model — that's exactly the case that doesn't exercise the tying path.
Smaller/optional: the _UnsettableInputQuantizer.enable* overrides catch user-facing direct calls, but set_from_attribute_config({"enable": True}) writes _disabled directly via setattr, so the only real defense is the runtime check in forward. The current docstring already explains this; just confirm the runtime guard is the load-bearing one and the method overrides are belt-and-suspenders.
- Apply output_quantizer in the torch.export branch of _QuantEmbedding.forward so users who opt into output activation quantization don't silently lose it during export. Matches QuantInputBase.forward's behavior.
- Detect Python-level weight tying (e.g. tied_word_embeddings → lm_head) in _process_quantized_modules and skip packing the embedding when the .weight Parameter is shared, with a UserWarning. Packing would otherwise reassign the embedding's .weight to a new uint8 Parameter, severing the tie and leaving the tied module pointing at a stale float Parameter.
- Add export-path tests covering the normal pack flow (weight → uint8 + weight_scale + weight_scale_2 buffers) and the tied-embedding skip path (weight unchanged, warning raised, tie preserved).

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
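A rough sketch of the tied-weight guard described in the second bullet. The function names and structure here are illustrative, not the PR's exact code; the real change lives in _process_quantized_modules.

```python
import warnings

import torch.nn as nn


def embedding_weight_is_tied(model: nn.Module, embedding: nn.Embedding) -> bool:
    """Return True if another module shares the exact same .weight Parameter."""
    return any(
        module is not embedding and getattr(module, "weight", None) is embedding.weight
        for _, module in model.named_modules()
    )


def maybe_pack_embedding(model: nn.Module, embedding: nn.Embedding) -> None:
    if embedding_weight_is_tied(model, embedding):
        warnings.warn(
            "Skipping packed export for tied embedding weight; packing would "
            "replace the shared Parameter and sever the tie.",
            UserWarning,
        )
        return
    # ... pack embedding.weight into uint8 + weight_scale buffers here ...
```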
The previous design raised in _QuantEmbedding.forward whenever
input_quantizer.is_enabled, on the theory that any non-disable config was
an explicit user mistake. That assumption was wrong for wildcard configs:
the default QuantizeConfig is just [{"quantizer_name": "*", "cfg":
{"num_bits": 8, ...}}] (no embedding opt-out), so the wildcard enables
embed_tokens.input_quantizer for tiny Llama-style tests and the forward
guard fires — breaking test_peft_save_load and test_transformers_save_load.
Switch _UnsettableInputQuantizer.set_from_attribute_config to absorb the
incoming config like a normal quantizer, then force _disabled = True at
the end. The "throw on explicit set" semantics are preserved via the
.enable / .enable_quant / .enable_calib overrides, which catch the direct
mistakes users would actually make. The forward-time guard (and the
corresponding test) are removed since the invariant is now maintained at
the configure step.
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
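A minimal sketch of the override pattern this commit describes, with a stub in place of the real TensorQuantizer. Only the pieces needed to show the absorb-then-force-disabled behavior are modeled.

```python
class _TensorQuantizerStub:
    """Stand-in for the real TensorQuantizer; only the relevant pieces."""

    def __init__(self):
        self._disabled = True

    def set_from_attribute_config(self, cfg: dict) -> None:
        # Simplified: a real quantizer also absorbs num_bits, axis, etc.
        if "enable" in cfg:
            self._disabled = not cfg["enable"]


class _UnsettableInputQuantizerSketch(_TensorQuantizerStub):
    def set_from_attribute_config(self, cfg: dict) -> None:
        # Absorb wildcard configs like a normal quantizer ...
        super().set_from_attribute_config(cfg)
        # ... then restore the invariant: embedding indices are never quantized.
        self._disabled = True

    def enable(self):
        raise RuntimeError("input_quantizer on nn.Embedding cannot be enabled")

    # Direct calls to any of the enable* mutators are equally invalid.
    enable_quant = enable
    enable_calib = enable


q = _UnsettableInputQuantizerSketch()
q.set_from_attribute_config({"enable": True})  # wildcard config accepted silently
assert q._disabled  # ... but the quantizer stays disabled
```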
What does this PR do?
Type of change: new feature
Register nn.Embedding in QuantModuleRegistry so the embedding table and lookup activations participate in quantization end-to-end:

- modelopt/torch/quantization/nn/modules/quant_embedding.py exposes weight_quantizer (embedding table), output_quantizer (lookup activations, off by default), and an input_quantizer placeholder. Embedding inputs are integer indices that cannot be fake-quantized, so direct enable() / enable_quant() / enable_calib() calls on input_quantizer raise, and forward() raises if _disabled is flipped via any back door. Wildcard configs (*input_quantizer) are accepted silently so the stock deny-all → enable-wildcards → opt-out pattern in NVFP4_DEFAULT_CFG and friends still works.
- default_disabled_quantizers.yaml installs parent_class: nn.Embedding, enable: false so embedding quantization is opt-in and existing model behavior is unchanged.
- is_quantized_linear in core_utils.py early-returns False for nn.Embedding so AWQ / SmoothQuant / SVDQuant don't treat it as a GEMM op.
- _process_quantized_modules in unified_export_hf.py routes quantized nn.Embedding modules through _export_quantized_weight, so the exported checkpoint contains the packed NVFP4 / FP8 / INT bytes plus weight_scale* buffers, exactly like Linear layers.

Usage
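A hedged usage sketch based on the description above. mtq.quantize and NVFP4_DEFAULT_CFG are real modelopt entry points, but the glob used to opt the embedding's weight quantizer back in, and the calibration data, are assumptions; check the quantizer names that match your model.

```python
import copy

import modelopt.torch.quantization as mtq

cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
# Embedding quantization is opt-in (parent_class: nn.Embedding, enable: false in
# default_disabled_quantizers.yaml), so re-enable the embedding's weight
# quantizer by name. The glob below is hypothetical; mirror the NVFP4 weight
# quantizer settings from the config as needed.
cfg["quant_cfg"]["*embed_tokens.weight_quantizer"] = {"enable": True}


def forward_loop(m):
    # Run a few calibration batches through the model; calib_loader is assumed
    # to exist in the caller's scope.
    for batch in calib_loader:
        m(**batch)


# model is assumed to be an already loaded HF model with an embed_tokens module.
model = mtq.quantize(model, cfg, forward_loop)
```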
Testing
- Tests in tests/unit/torch/quantization/test_quant_embedding.py cover: default quantizer state, no-quant identity, per-tensor and per-row weight fake quant against the manual tensor_quant.fake_tensor_quant reference, output quantizer activation, locked-mutator raises (parametrized over enable / enable_quant / enable_calib), forward-time guard for back-door _disabled = False, and the wildcard-then-opt-out pattern. All 9 cases pass.
- mtq.quantize with NVFP4_DEFAULT_CFG + the embedding opt-in produces embedding.weight (uint8), embedding.weight_scale (float8_e4m3fn), and embedding.weight_scale_2 (float32) in the exported safetensors, with "quant_algo": "NVFP4" in hf_quant_config.json.

Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).
- parent_class: nn.Embedding, enable: false in default_disabled_quantizers.yaml, so existing model behavior is unchanged.
- CONTRIBUTING.md: N/A
- /claude review after the PR is up.

Additional Information