From 804eb74de75e266302a676af2ab923fba0128c8a Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Thu, 11 Jun 2026 15:19:44 +0530 Subject: [PATCH 1/2] feat(repeatkv): consolidate RepeatKV transform support for LLM/VLM and helper utilities This squashed commit combines the last 9 commits on this branch into one coherent change set for RepeatKV support and follow-up review updates. What changed - Added RepeatKV transform support for both LLM and VLM export paths. - Extended model/config handling to support generic attention head naming patterns (e.g., num_attention_heads, n_heads, n_head) to improve cross-model compatibility. - Added/updated RepeatKV operations and integration paths for DeepSeekV3 flows. - Added support for deriving an effective repeat_kv count based on model topology and number of devices. - Added RepeatKV handling for quantized projection layers (AWQ, GPTQ, QuantLinearORT, and FP8). - Improved wrapper-aware behavior for VLM encoder/decoder paths and prevented repeated application of ReplicateKVHeadTransform across wrappers. - Updated CI-related model mapping/test infra for repeat_kv checks across CausalLM/VLM scenarios and adjusted script flow around APIRunner/input-shape sequencing. - Refactored KV duplication logic into shared helpers: - moved helper logic to QEfficient/utils/repeat_kv_utils.py - centralized projection lookup, MLA checks, idempotency checks, and KV duplication - replaced in-class duplication code with utility-driven calls - Applied naming cleanup by renaming num_kv_heads_repeat to num_replicate_kv_heads. Review-driven updates - Included internal and PR review feedback updates across transform behavior, scripts, naming, and helper factoring. - Incorporated a revert+follow-up sequence from review iteration, keeping only the final intended behavior in this squashed result. Notes - Historical TODOs from intermediate commits were retained as context during development; this squashed state reflects the final net code on branch tip. 
Signed-off-by: Dhiraj Kumar Sah --- QEfficient/base/modeling_qeff.py | 45 ++- .../models/gemma3/modeling_gemma3.py | 1 + .../models/internvl/modeling_internvl.py | 1 + .../models/llama4/modeling_llama4.py | 1 + .../models/llava/modeling_llava.py | 1 + .../models/llava_next/modeling_llava_next.py | 1 + .../models/mistral3/modeling_mistral3.py | 1 + .../transformers/models/modeling_auto.py | 16 +- .../models/molmo/modeling_molmo.py | 1 + .../transformers/models/pytorch_transforms.py | 242 ++++++++++++---- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1 + .../models/qwen3_5/modeling_qwen3_5.py | 2 +- .../models/qwen3_vl/modeling_qwen3_vl.py | 1 + .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 1 + QEfficient/utils/config_utils.py | 68 +++++ QEfficient/utils/constants.py | 5 + QEfficient/utils/repeat_kv_utils.py | 270 ++++++++++++++++++ QEfficient/utils/test_utils.py | 29 ++ dbg.log | 0 examples/kimi_k2/README.md | 4 +- examples/kimi_k2/export_kimik2.py | 8 +- examples/text_generation/run_kimik2.py | 8 +- tests/configs/causal_model_configs.json | 15 +- .../causal_lm_models/check_causal_models.py | 67 ++++- .../causal_lm_models/test_causal_lm_models.py | 29 ++ .../test_image_text_to_text_models.py | 75 ++++- 26 files changed, 811 insertions(+), 82 deletions(-) create mode 100644 QEfficient/utils/config_utils.py create mode 100644 QEfficient/utils/repeat_kv_utils.py create mode 100644 dbg.log diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 35d46fb3a5..9e6076209e 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -47,6 +47,7 @@ require_value, to_named_specializations, ) +from QEfficient.utils.config_utils import calculate_num_replicate_kv_heads from QEfficient.utils.export_utils import export_wrapper logger = logging.getLogger(__name__) @@ -735,23 +736,48 @@ def transform( **compiler_options, ): # Apply the transformations that are dependent on compilation parameters + def 
_transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: + """ + Use the shared wrapped model as transform-tracking root when available. + This lets encoder/decoder wrappers coordinate one-time transforms. + """ + wrapped = getattr(module, "model", None) + return wrapped if isinstance(wrapped, torch.nn.Module) else module qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None) - model_config = getattr(self.model, "config", None) or getattr(self.model.model, "config", None) + model_config = getattr(self.model, "config", None) or getattr( + getattr(self.model, "model", None), "config", None + ) + num_replicate_kv_heads = 1 + if model_config is not None: + num_replicate_kv_heads = calculate_num_replicate_kv_heads( + num_devices=num_devices, + text_model_config=model_config, + ) if model_config: - if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): - if qaic_config: - if qaic_config.get("blocking_mode", None) == "h": - qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) - num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) + if qaic_config is not None: + num_replicate_kv_heads = qaic_config.get("num_replicate_kv_heads", num_replicate_kv_heads) + qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads + transform_root = _transform_tracking_root(self.model) + applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set()) + should_apply_repeat_kv = num_replicate_kv_heads is not None and num_replicate_kv_heads > 1 + if not should_apply_repeat_kv: + replicate_kv_transformed = False + elif ReplicateKVHeadTransform.__name__ in applied_transforms: + replicate_kv_transformed = False + logger.warning("Skipping RepeatKVTransform: already applied on this model instance.") + else: self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply( - self.model, num_kv_heads_repeat + self.model, + num_replicate_kv_heads, ) if 
replicate_kv_transformed: - self.hash_params["config"] = self.model.config.to_diff_dict() - + applied_transforms.add(ReplicateKVHeadTransform.__name__) + setattr(transform_root, "_qeff_runtime_transforms_applied", applied_transforms) + if replicate_kv_transformed: + self.hash_params["config"] = self.model.config.to_diff_dict() blocking_config = build_transformer_blocking_config_for_transform( model_config, ctx_len=ctx_len, @@ -768,6 +794,7 @@ def transform( if blocking_config is not None: self.model, _ = BlockingAttentionTransform.apply(self.model, attn_blocking_config=blocking_config) self.hash_params["blocking_kwargs"] = blocking_config + self.hash_params["num_replicate_kv_heads"] = num_replicate_kv_heads @dump_qconfig def _compile( diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 96bf6cd1a0..289a1bee77 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -626,6 +626,7 @@ class QEffGemma3EncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model.model + self.config = self.model.config self.model.vision_model = self.model.vision_tower def get_submodules_for_export(self) -> Type[nn.Module]: diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 1496e972b0..af457bb50d 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -20,6 +20,7 @@ class QEffInternEncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 3dcd69e79e..6dc411c400 
100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -831,6 +831,7 @@ class QEffLlama4EncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 3cbe77a95b..f45d4847c4 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -29,6 +29,7 @@ def __init__(self, model): super().__init__() self.model = model self.model.vision_model = self.model.model.vision_tower + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 9d7881b012..1a66a7c0b7 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -29,6 +29,7 @@ def __init__(self, model): super().__init__() self.model = model self.model.vision_model = self.model.model.vision_tower + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index d8cbdb17b6..5029a93984 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -183,6 +183,7 @@ class QEFFMistral3EncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config self.model.model.vision_model = self.model.model.vision_tower def get_submodules_for_export(self) -> 
Type[nn.Module]: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8009de20be..b3d5b0a60b 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1461,6 +1461,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option ) _resolve_torch_dtype(kwargs) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -1469,6 +1470,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option model, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) @@ -1593,7 +1595,12 @@ def export( if prefill_only and prefill_seq_len > 1: offload_pt_weights = False # to keep weight for decode onnx else: - offload_pt_weights = kwargs.get("offload_pt_weights", True) + num_replicate_kv_heads = ( + (self.lang_model.model.qaic_config or {}).get("num_replicate_kv_heads", 1) + if hasattr(self.lang_model.model, "qaic_config") + else 1 + ) + offload_pt_weights = kwargs.get("offload_pt_weights", num_replicate_kv_heads <= 1) if not skip_lang: self.lang_model.export( @@ -2616,6 +2623,7 @@ def from_pretrained( config._attn_implementation = "eager" config.vision_config.use_flash_attn = "false" _resolve_torch_dtype(kwargs) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -2624,6 +2632,7 @@ def from_pretrained( model, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) @@ -3275,6 +3284,7 @@ def 
from_pretrained( ) _resolve_torch_dtype(kwargs) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) if layerwise: # Layer-wise mode: build the outer model on the meta device so the # caller's ``from_pretrained`` does not pull the full checkpoint @@ -3293,6 +3303,7 @@ def from_pretrained( continuous_batching=continuous_batching, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) # Mark the wrapper so its compile() can default ``layerwise=True`` if @@ -3550,6 +3561,7 @@ def from_pretrained( ) _resolve_torch_dtype(kwargs) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) if layerwise: # Layer-wise mode: build the outer model on the meta device. The # caller still gets a typed wrapper, but no checkpoint weights are @@ -3570,6 +3582,7 @@ def from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, continuous_batching=continuous_batching, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) instance = cls( @@ -3578,6 +3591,7 @@ def from_pretrained( qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, max_seq_len_cached=max_seq_len_cached, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) if layerwise: diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index d59ca4e017..b673d9e060 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -565,6 +565,7 @@ class QEffMolmoEncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 211b60fa82..20779c9dac 100755 
--- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -9,7 +9,6 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union -import torch from torch import nn from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, @@ -299,7 +298,11 @@ ) from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaModel -from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform +from QEfficient.base.pytorch_transforms import ( + ExternalModuleMapperTransform, + ModuleMappingTransform, + ModuleMutatorTransform, +) from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function from QEfficient.transformers.models.bert.modeling_bert import ( @@ -617,6 +620,16 @@ QEffT5LayerNorm, QEffT5Stack, ) + +# from QEfficient.transformers.models.repeat_kv_utils import ( +# duplicate_kv_projection_weights, +# get_attention_module, +# get_projection_layer, +# get_text_model, +# is_mla_model, +# is_replication_applied, +# replication_targets, +# ) from QEfficient.transformers.models.wav2vec2.modeling_wav2vec2 import ( QEffWav2Vec2Encoder, QEffWav2Vec2EncoderStableLayerNorm, @@ -633,7 +646,23 @@ from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry from QEfficient.transformers.sampler.sampler import sampler_forward from QEfficient.transformers.spd.spd_transform_forward import tlm_forward +from QEfficient.utils.config_utils import ( + resolve_attention_heads, + resolve_hidden_size, + resolve_kv_heads, + set_kv_head_aliases, +) +from QEfficient.utils.constants import ATTENTION_HEAD_CONFIG_KEYS, HIDDEN_SIZE_CONFIG_KEYS, KV_HEAD_CONFIG_KEYS from QEfficient.utils.logging_utils import logger +from QEfficient.utils.repeat_kv_utils import ( + duplicate_kv_projection_weights, + 
get_attention_module, + get_projection_layer, + get_text_model, + is_mla_model, + is_replication_applied, + replication_targets, +) SPD_TARGET = "target" @@ -969,72 +998,187 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform): } -class ReplicateKVHeadTransform: +class ReplicateKVHeadTransform(ModuleMutatorTransform): """ Replicates KV heads in attention modules to match the number of KV heads in the target model. This transform is used when the source model has fewer KV heads than required in target model. """ - def _duplicate_weights_for_linear_layer( - layer: nn.Module, orig_kv_heads: int, repeat: int, dim: int, hidden_size: int - ): - new_kv_heads = repeat # for mla + _module_mapping = { + QEffCodeGenForCausalLM, + QEffFalconForCausalLM, + QEffGPT2LMHeadModel, + QEffGPTJForCausalLM, + QEffLlamaForCausalLM, + QEffLlama4ForConditionalGeneration, + QEffLlavaForConditionalGeneration, + QEffLlavaNextForConditionalGeneration, + QEffGemmaForCausalLM, + QEffGemma2ForCausalLM, + QEffGemma3ForConditionalGeneration, + QEffGlm4MoeForCausalLM, + QEffGraniteForCausalLM, + QEffGraniteMoeForCausalLM, + QEffMllamaForConditionalGeneration, + QEffMistralForCausalLM, + QEffMistral3ForConditionalGeneration, + QEffMixtralForCausalLM, + QEffMptForCausalLM, + QEffPhiForCausalLM, + QEffPhi3ForCausalLM, + QEffQwen2ForCausalLM, + QEffQwen3ForCausalLM, + QEffQwen3_5ForConditionalGeneration, + QEffQwen3_5MoeForConditionalGeneration, + QEffQwen_2_5_vl_ForConditionalGeneration, + QEffQwen3MoeForCausalLM, + QEffQwen3VLForConditionalGeneration, + QEffQwen3VLMoeForConditionalGeneration, + QEffStarcoder2ForCausalLM, + QEffGPTBigCodeForCausalLM, + QEffOlmo2ForCausalLM, + } + _module_string_mapping = { + "DeepseekV3ForCausalLM", + "InternVLChatModel", + "MolmoForCausalLM,", + "QEffGemma3DecoderWrapper", + "QEffGemma3EncoderWrapper", + "QEffInternDecoderWrapper", + "QEffInternEncoderWrapper", + "QEffLlama4DecoderWrapper", + "QEffLlama4EncoderWrapper", + "QEFFLlavaDecoderWrapper", + 
"QEFFLlavaEncoderWrapper", + "QEffLlavaNextDecoderWrapper", + "QEffLlavaNextEncoderWrapper", + "QEFFMistral3DecoderWrapper", + "QEFFMistral3EncoderWrapper", + "QEffMolmoDecoderWrapper", + "QEffMolmoEncoderWrapper", + "QEffQwen_2_5_vl_DecoderWrapper", + "QEffQwen_2_5_vl_EncoderWrapper", + "QEffQwen3VLDecoderWrapper", + "QEffQwen3VLEncoderWrapper", + "QEffQwen3_5EncoderWrapper", + "QEffQwen3_5DecoderWrapper", + "QEffQwen3_5MoeEncoderWrapper", + "QEffQwen3_5MoeDecoderWrapper", + } - layer.weight.data = torch.repeat_interleave( - layer.weight.data.view(orig_kv_heads, dim, hidden_size), repeat, 0 - ).view(new_kv_heads * dim, hidden_size) + @classmethod + def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: int) -> nn.Module: + """ + Mutates the matched top-level model module in-place by replicating its KV heads. - if layer.bias is not None: - layer.bias.data = torch.repeat_interleave(layer.bias.data.view(orig_kv_heads, dim), repeat, 0).view( - new_kv_heads * dim - ) + Args: + original_module: The matched top-level model module to mutate. + parent_module: The parent module (unused, present for interface compatibility). + n_repeat: The number of times to repeat the KV heads. - def _get_text_model(model): - """ - Determine and return the appropriate text_model from a given model object. + Returns: + The mutated module (same object, modified in-place). 
""" - # Check for VLMs - if hasattr(model, "language_model"): - if hasattr(model.language_model, "model"): - return model.language_model.model - else: - return model.language_model - # Check for CausalLMs - if hasattr(model, "model"): - return model.model + text_model = get_text_model(original_module) + if is_replication_applied(original_module, text_model): + logger.warning("KV head replication already applied for this model instance; skipping.") + return original_module + + cfg = text_model.config + if is_mla_model(text_model): + logger.warning("Skipping RepeatKVTransform: MLA models don't apply replicate KV changes.") + return original_module + + orig_kv_heads = resolve_kv_heads(cfg) + num_attention_heads = resolve_attention_heads(cfg) + hidden_size = resolve_hidden_size(cfg) + + if orig_kv_heads is None or num_attention_heads is None or hidden_size is None: + raise ValueError( + "Unable to resolve attention/KV heads or hidden size from config for RepeatKV transform. " + f"Supported attention keys={ATTENTION_HEAD_CONFIG_KEYS}, kv keys={KV_HEAD_CONFIG_KEYS}, " + f"hidden size keys={HIDDEN_SIZE_CONFIG_KEYS}." + ) + if orig_kv_heads < 1 or num_attention_heads < 1: + raise ValueError( + f"Invalid head values for RepeatKV transform: " + f"num_attention_heads={num_attention_heads}, num_key_value_heads={orig_kv_heads}" + ) + new_kv_heads = n_repeat * orig_kv_heads + if new_kv_heads > num_attention_heads or (num_attention_heads % new_kv_heads) != 0: + raise ValueError( + f"Invalid RepeatKV configuration: num_attention_heads={num_attention_heads}, " + f"orig_kv_heads={orig_kv_heads}, num_replicate_kv_heads={n_repeat}, new_kv_heads={new_kv_heads}. " + "Expected new_kv_heads <= num_attention_heads and divisibility." 
+ ) - raise AttributeError("No suitable text model found in the provided model.") + cfg.orig_kv_heads = orig_kv_heads + set_kv_head_aliases(cfg, new_kv_heads) + + logger.warning(f"Original KV heads: {orig_kv_heads}") + logger.warning(f"Modified KV heads: {new_kv_heads}") + for block in text_model.layers: + attn = get_attention_module(block) + if hasattr(attn, "num_key_value_heads"): + attn.num_key_value_heads = new_kv_heads + if hasattr(attn, "n_kv_heads"): + attn.n_kv_heads = new_kv_heads + + n_kv_groups = num_attention_heads // new_kv_heads + if hasattr(attn, "num_key_value_groups"): + attn.num_key_value_groups = n_kv_groups + if hasattr(attn, "n_kv_groups"): + attn.n_kv_groups = n_kv_groups + head_dim = getattr(attn, "head_dim", hidden_size // num_attention_heads) + k_proj = get_projection_layer(attn, ("k_proj", "key_proj")) + v_proj = get_projection_layer(attn, ("v_proj", "value_proj")) + duplicate_kv_projection_weights( + k_proj, + orig_kv_heads, + n_repeat, + head_dim, + hidden_size, + layer_name=f"{attn.__class__.__name__}.k_proj", + ) + duplicate_kv_projection_weights( + v_proj, + orig_kv_heads, + n_repeat, + head_dim, + hidden_size, + layer_name=f"{attn.__class__.__name__}.v_proj", + ) + + for target in replication_targets(original_module, text_model): + setattr(target, "_qeff_kv_replication_applied", True) + return original_module @classmethod - def apply(cls, model: nn.Module, num_kv_heads_repeat: int = 1) -> nn.Module: + def apply(cls, model: nn.Module, num_replicate_kv_heads: Optional[int] = None, **kwargs) -> Tuple[nn.Module, bool]: """ Replicates KV heads in attention modules based on provided multiplier. Args: model: The model to apply the transform to. - num_kv_heads_repeat: The number of times to repeat the KV heads. + kwargs: Additional arguments for the transformation. Includes: + - num_replicate_kv_heads: The number of times to repeat the KV heads. 
""" + if num_replicate_kv_heads is None: + n_repeat = kwargs.pop("num_replicate_kv_heads", 1) + else: + kwargs.pop("num_replicate_kv_heads", None) + n_repeat = num_replicate_kv_heads transformed = False - if num_kv_heads_repeat is not None and num_kv_heads_repeat > 1: - text_model = cls._get_text_model(model) - - orig_kv_heads = 1 # for mla #text_model.config.num_key_value_heads - new_kv_heads = num_kv_heads_repeat * orig_kv_heads - text_model.config.orig_kv_heads = orig_kv_heads - text_model.config.num_key_value_heads = new_kv_heads - - hidden_size = text_model.config.hidden_size - - logger.warning(f"Original KV heads: {orig_kv_heads}") - logger.warning(f"Modified KV heads: {new_kv_heads}") - transformed = True - for block in text_model.layers: - attn = getattr(block, "cross_attn", getattr(block, "self_attn", None)) - attn.num_key_value_heads = new_kv_heads - head_dim = attn.kv_lora_rank + attn.qk_rope_head_dim - - cls._duplicate_weights_for_linear_layer( - attn.kv_a_proj_with_mqa, orig_kv_heads, num_kv_heads_repeat, head_dim, hidden_size + if n_repeat is not None and n_repeat > 1: + if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping): + text_model = get_text_model(model) + was_applied = is_replication_applied(model, text_model) + cls.mutate(model, None, n_repeat) + is_applied = is_replication_applied(model, text_model) + transformed = (not was_applied) and is_applied + else: + raise NotImplementedError( + f"Model class {model.__class__.__name__} is not supported for KV head replication." 
) return model, transformed diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 2959a8d0de..1c15f24815 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -747,6 +747,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py index 54b146b36b..eb58eba8d9 100644 --- a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1636,7 +1636,7 @@ def get_specializations( for h, w, f in zip(height, width, num_frames): resized_height, resized_width = smart_resize( - height=h, width=w, factor=image_factor, min_pixels=min_pixels, max_pixels=max_pixels + height=h, width=w, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels ) grid_h, grid_w = resized_height // patch_size, resized_width // patch_size grid_height = grid_h * grid_w diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 08d90dbaf8..fc49c1fb43 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -653,6 +653,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 
3a6f23c2cc..1a04d0629b 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -761,6 +761,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/utils/config_utils.py b/QEfficient/utils/config_utils.py new file mode 100644 index 0000000000..4b28d54880 --- /dev/null +++ b/QEfficient/utils/config_utils.py @@ -0,0 +1,68 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import Iterable, Optional + +from QEfficient.utils.constants import ATTENTION_HEAD_CONFIG_KEYS, HIDDEN_SIZE_CONFIG_KEYS, KV_HEAD_CONFIG_KEYS + + +def get_first_config_value(config, names: Iterable[str], default=None, cast_int: bool = False): + for name in names: + value = getattr(config, name, None) + if value is not None: + return int(value) if cast_int else value + return default + + +def resolve_attention_heads(config) -> Optional[int]: + return get_first_config_value(config, ATTENTION_HEAD_CONFIG_KEYS, cast_int=True) + + +def resolve_kv_heads(config) -> Optional[int]: + value = get_first_config_value(config, KV_HEAD_CONFIG_KEYS, cast_int=True) + if value is None: + value = resolve_attention_heads(config) + return value + + +def resolve_hidden_size(config) -> Optional[int]: + return get_first_config_value(config, HIDDEN_SIZE_CONFIG_KEYS, cast_int=True) + + +def set_kv_head_aliases(config, value: int): + setattr(config, "num_key_value_heads", value) + for key in KV_HEAD_CONFIG_KEYS: + if hasattr(config, key): + setattr(config, key, value) + + +def 
calculate_num_replicate_kv_heads(num_devices: int, text_model_config) -> int: + """ + Choose a KV-repeat value from model config and device count. + + Primary criteria: + 1. num_kv_heads * repeat is divisible by num_devices + 2. num_attention_heads is divisible by (num_kv_heads * repeat) + + Fallback: + repeat = num_attention_heads / num_kv_heads (integer-truncated if needed). + """ + num_attention_heads = resolve_attention_heads(text_model_config) + num_kv_heads = resolve_kv_heads(text_model_config) + + if num_attention_heads is None or num_kv_heads is None or num_attention_heads < 1 or num_kv_heads < 1: + return 1 + + num_devices = max(1, int(num_devices)) + max_repeat = max(1, int(num_attention_heads / num_kv_heads)) + + for repeat in range(max_repeat, 0, -1): + repeated_kv_heads = num_kv_heads * repeat + if (repeated_kv_heads % num_devices == 0) and (num_attention_heads % repeated_kv_heads == 0): + return repeat + + return 1 diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 4e1032a40f..baf6a27422 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -140,6 +140,11 @@ def get_default_aic_hw_version() -> str: DEFAULT_AIC_HW_VERSION = get_default_aic_hw_version() ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 +# Generic config key aliases used across model families. 
+ATTENTION_HEAD_CONFIG_KEYS = ("num_attention_heads", "n_head", "n_heads", "num_heads") +KV_HEAD_CONFIG_KEYS = ("num_key_value_heads", "n_kv_heads", "num_kv_heads", "effective_n_kv_heads") +HIDDEN_SIZE_CONFIG_KEYS = ("hidden_size", "n_embd", "d_model") + # InternVL constants # Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B INTERN_FEATURE_SIZE = 256 diff --git a/QEfficient/utils/repeat_kv_utils.py b/QEfficient/utils/repeat_kv_utils.py new file mode 100644 index 0000000000..8f7251436e --- /dev/null +++ b/QEfficient/utils/repeat_kv_utils.py @@ -0,0 +1,270 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import Optional, Sequence + +import torch +import torch.nn as nn + +from QEfficient.customop.matmulnbits import QuantLinearORT, dequantize_blockwise_bits +from QEfficient.transformers.quantizers.awq import WQLinear_GEMM +from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ +from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear +from QEfficient.utils.config_utils import resolve_attention_heads, resolve_hidden_size, resolve_kv_heads + +TEXT_MODEL_CANDIDATE_PATHS = ( + ("language_model",), + ("language_model", "model"), + ("model", "language_model"), + ("model", "language_model", "model"), + ("model", "model", "language_model"), + ("model", "model", "language_model", "model"), + ("model",), + ("model", "model"), + ("transformer",), + ("transformer", "model"), + ("llm",), + ("llm", "model"), + ("backbone",), +) + + +def duplicate_kv_projection_weights( + layer: nn.Module, + orig_kv_heads: int, + repeat: int, + head_dim: int, + hidden_size: int, + layer_name: Optional[str] = None, +) -> None: + """ + 
Duplicate KV projection weights for one projection layer to implement repeat-kv transform. + """ + layer_prefix = f"{layer_name}: " if layer_name else "" + new_kv_heads = repeat * orig_kv_heads + + if isinstance(layer, WQLinear_GEMM): + if layer.qweight.shape[1] % orig_kv_heads != 0: + raise ValueError( + f"{layer_prefix}Invalid AWQ qweight shape for RepeatKV: qweight.shape={tuple(layer.qweight.shape)}, " + f"orig_kv_heads={orig_kv_heads}" + ) + if layer.qzeros.shape[1] % orig_kv_heads != 0 or layer.scales.shape[1] % orig_kv_heads != 0: + raise ValueError( + f"{layer_prefix}Invalid AWQ qzeros/scales shape for RepeatKV: qzeros.shape={tuple(layer.qzeros.shape)}, " + f"scales.shape={tuple(layer.scales.shape)}, orig_kv_heads={orig_kv_heads}" + ) + + layer.qweight.data = torch.repeat_interleave( + layer.qweight.data.view(layer.qweight.shape[0], orig_kv_heads, -1), repeat, dim=1 + ).view(layer.qweight.shape[0], -1) + layer.qzeros.data = torch.repeat_interleave( + layer.qzeros.data.view(layer.qzeros.shape[0], orig_kv_heads, -1), repeat, dim=1 + ).view(layer.qzeros.shape[0], -1) + layer.scales.data = torch.repeat_interleave( + layer.scales.data.view(layer.scales.shape[0], orig_kv_heads, -1), repeat, dim=1 + ).view(layer.scales.shape[0], -1) + layer.out_features = layer.out_features * repeat + + elif isinstance(layer, QuantLinearGPTQ): + if layer.qweight.shape[1] % orig_kv_heads != 0: + raise ValueError( + f"{layer_prefix}Invalid GPTQ qweight shape for RepeatKV: qweight.shape={tuple(layer.qweight.shape)}, " + f"orig_kv_heads={orig_kv_heads}" + ) + if layer.qzeros.shape[1] % orig_kv_heads != 0 or layer.scales.shape[1] % orig_kv_heads != 0: + raise ValueError( + f"{layer_prefix}Invalid GPTQ qzeros/scales shape for RepeatKV: qzeros.shape={tuple(layer.qzeros.shape)}, " + f"scales.shape={tuple(layer.scales.shape)}, orig_kv_heads={orig_kv_heads}" + ) + + layer.qweight.data = torch.repeat_interleave( + layer.qweight.data.view(layer.qweight.shape[0], orig_kv_heads, -1), repeat, 
dim=1 + ).view(layer.qweight.shape[0], -1) + layer.qzeros.data = torch.repeat_interleave( + layer.qzeros.data.view(layer.qzeros.shape[0], orig_kv_heads, -1), repeat, dim=1 + ).view(layer.qzeros.shape[0], -1) + layer.scales.data = torch.repeat_interleave( + layer.scales.data.view(layer.scales.shape[0], orig_kv_heads, -1), repeat, dim=1 + ).view(layer.scales.shape[0], -1) + layer.out_features = layer.out_features * repeat + + elif isinstance(layer, QuantLinearORT): + float_weight, zeros_per_group, scales_per_group = dequantize_blockwise_bits( + layer.qweight, + layer.scales, + layer.qzeros, + layer.bits, + layer.group_size, + layer.g_idx, + layer.in_features, + layer.out_features, + ) + if float_weight.shape[0] % orig_kv_heads != 0: + raise ValueError( + f"{layer_prefix}Invalid QuantLinearORT weight shape for RepeatKV: " + f"weight.shape={tuple(float_weight.shape)}, orig_kv_heads={orig_kv_heads}" + ) + + duplicated_weight = torch.repeat_interleave( + float_weight.view(orig_kv_heads, -1, float_weight.shape[1]), + repeat, + dim=0, + ).view(new_kv_heads * (float_weight.shape[0] // orig_kv_heads), float_weight.shape[1]) + + duplicated_zeros = torch.repeat_interleave( + zeros_per_group.view(orig_kv_heads, -1, zeros_per_group.shape[1]), + repeat, + dim=0, + ).view(new_kv_heads * (zeros_per_group.shape[0] // orig_kv_heads), zeros_per_group.shape[1]) + duplicated_scales = torch.repeat_interleave( + scales_per_group.view(orig_kv_heads, -1, scales_per_group.shape[1]), + repeat, + dim=0, + ).view(new_kv_heads * (scales_per_group.shape[0] // orig_kv_heads), scales_per_group.shape[1]) + + original_out_features = layer.out_features + layer.out_features = original_out_features * repeat + q_rows = layer.in_features // layer.group_size + layer.qweight = torch.zeros( + (layer.out_features, q_rows, layer.group_size // (8 // layer.bits)), + dtype=layer.qweight.dtype, + device=layer.qweight.device, + ) + layer.qzeros = torch.zeros( + (q_rows + (q_rows & 1)) * (layer.out_features // 8 * 
layer.bits), + dtype=layer.qzeros.dtype, + device=layer.qzeros.device, + ) + layer.scales = torch.zeros( + (q_rows * layer.out_features), + dtype=layer.scales.dtype, + device=layer.scales.device, + ) + + linear = nn.Linear(layer.in_features, layer.out_features, bias=False, dtype=duplicated_weight.dtype) + linear.weight.data = duplicated_weight.to(linear.weight.dtype) + layer.pack( + linear, + duplicated_scales.contiguous().to(layer.scales.dtype), + duplicated_zeros.contiguous().to(torch.int32), + layer.g_idx, + ) + + elif isinstance(layer, FP8DeQuantLinear): + layer.weight.data = torch.repeat_interleave( + layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), + repeat, + dim=0, + ).view(new_kv_heads * head_dim, hidden_size) + layer.weight_scale.data = torch.repeat_interleave( + layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, dim=0 + ).view(new_kv_heads * head_dim, -1) + + else: + layer.weight.data = torch.repeat_interleave( + layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), + repeat, + dim=0, + ).view(new_kv_heads * head_dim, hidden_size) + + if layer.bias is not None: + layer.bias.data = torch.repeat_interleave( + layer.bias.data.view(orig_kv_heads, head_dim), + repeat, + dim=0, + ).view(new_kv_heads * head_dim) + + +def get_attention_module(block: nn.Module) -> nn.Module: + for attr in ("cross_attn", "self_attn", "attention", "attn"): + attn = getattr(block, attr, None) + if attn is not None: + return attn + raise AttributeError(f"No attention module found in block type {block.__class__.__name__}") + + +def get_projection_layer(attn: nn.Module, names: Sequence[str]) -> nn.Module: + for name in names: + layer = getattr(attn, name, None) + if layer is not None: + return layer + raise AttributeError(f"Missing projection layer in {attn.__class__.__name__}; expected one of {tuple(names)}") + + +def is_mla_attention(attn: nn.Module) -> bool: + return hasattr(attn, "kv_a_proj_with_mqa") and hasattr(attn, "kv_lora_rank") and 
hasattr(attn, "qk_rope_head_dim") + + +def is_mla_model(text_model: nn.Module) -> bool: + for block in getattr(text_model, "layers", []): + try: + attn = get_attention_module(block) + except AttributeError: + continue + if is_mla_attention(attn): + return True + return False + + +def is_valid_text_model(candidate: nn.Module) -> bool: + if candidate is None: + return False + cfg = getattr(candidate, "config", None) + layers = getattr(candidate, "layers", None) + attn_heads = resolve_attention_heads(cfg) if cfg is not None else None + kv_heads = resolve_kv_heads(cfg) if cfg is not None else None + hidden_size = resolve_hidden_size(cfg) if cfg is not None else None + return ( + cfg is not None + and layers is not None + and attn_heads is not None + and kv_heads is not None + and hidden_size is not None + ) + + +def get_text_model(model: nn.Module) -> nn.Module: + for path in TEXT_MODEL_CANDIDATE_PATHS: + candidate = model + valid_path = True + for attr in path: + if not hasattr(candidate, attr): + valid_path = False + break + candidate = getattr(candidate, attr) + if valid_path and is_valid_text_model(candidate): + return candidate + + raise AttributeError( + f"No suitable text model found in the provided model ({model.__class__.__name__}). " + "Expected a module with `layers` and text `config` attributes." 
+ ) + + +def get_replication_root(model: nn.Module) -> nn.Module: + candidate = getattr(model, "model", None) + return candidate if isinstance(candidate, nn.Module) else model + + +def replication_targets(model: nn.Module, text_model: Optional[nn.Module] = None): + targets = [] + root = get_replication_root(model) + if root is not None: + targets.append(root) + if text_model is not None: + targets.append(text_model) + cfg = getattr(text_model, "config", None) + if cfg is not None: + targets.append(cfg) + return targets + + +def is_replication_applied(model: nn.Module, text_model: Optional[nn.Module] = None) -> bool: + return any( + getattr(target, "_qeff_kv_replication_applied", False) for target in replication_targets(model, text_model) + ) diff --git a/QEfficient/utils/test_utils.py b/QEfficient/utils/test_utils.py index 131ff59e26..51ebffba55 100644 --- a/QEfficient/utils/test_utils.py +++ b/QEfficient/utils/test_utils.py @@ -289,6 +289,14 @@ def load_qeff_model_with_sampler( return qeff_model +def get_text_config(config): + if hasattr(config, "text_config"): + return config.text_config + elif hasattr(config, "llm_config"): + return config.llm_config + return config + + # Processor class for InternVL models class InternProcessor: """ @@ -492,6 +500,27 @@ class ModelConfig: "Qwen/Qwen3.6-35B-A3B", } + REPEAT_KV_TEST_MODELS = { + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "ibm-granite/granite-3.1-1b-a400m-base", + "Qwen/Qwen2-0.5B", + "bigcode/starcoder2-3b", + "meta-llama/Llama-3.2-1B", + "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + "TheBloke/Llama-2-7B-GPTQ", + "neuralmagic/Llama-3.2-3B-Instruct-FP8", + "ibm-granite/granite-3.1-2b-instruct", + "llava-hf/llava-1.5-7b-hf", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + "Qwen/Qwen2.5-VL-3B-Instruct", + "Qwen/Qwen3-VL-2B-Instruct", + "Qwen/Qwen3-VL-30B-A3B-Instruct", + "allenai/Molmo-7B-D-0924", + "OpenGVLab/InternVL2_5-1B", + "Qwen/Qwen3.5-0.8B", + } + EXTERNAL_MODELS 
= { "hpcai-tech/grok-1": { "pytorch_hf_tokens_custom_case": [ diff --git a/dbg.log b/dbg.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/kimi_k2/README.md b/examples/kimi_k2/README.md index 230127ebbe..4fae4a8cfb 100644 --- a/examples/kimi_k2/README.md +++ b/examples/kimi_k2/README.md @@ -20,9 +20,9 @@ mla_absorption has 3 keys: # Blocking We have also implemented KV head replication, HEAD Blocking and KV Blocking which can be enable like this : - For No Blocking : qaic_config = {"mla_absorption" : mla_absorption} -- For No blocking with kv head replication : qaic_config = {"mla_absorption" : mla_absorption, "num_kv_heads_repeat": TS} +- For No blocking with kv head replication : qaic_config = {"mla_absorption" : mla_absorption, "num_replicate_kv_heads": TS} - For KV blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking -- For Head Blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +- For Head Blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_replicate_kv_heads": TS} for h blocking, it internally sets head_block_size equal to num_devices/num_replicate_kv_heads - Currently Decode-Only model is giving best perf with Head Blocking and compressed cache. - Contnuous batching is not enabled yet. 
\ No newline at end of file diff --git a/examples/kimi_k2/export_kimik2.py b/examples/kimi_k2/export_kimik2.py index 1e70352165..ba6b26c064 100644 --- a/examples/kimi_k2/export_kimik2.py +++ b/examples/kimi_k2/export_kimik2.py @@ -18,16 +18,16 @@ # qaic_config = None # Full PKV Cache # qaic_config = {"enable_blocking": True, "blocking_mode": "h"} # Full PKV Cache with Head Blocking # qaic_config = {"mla_absorption": mla_absorption} # for No Blocking -# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "num_replicate_kv_heads": TS} # No blocking with kv head replication # qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking -# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_kv_heads_repeat":TS} # for KV blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_replicate_kv_heads":TS} # for KV blocking with kv head replication qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", - "num_kv_heads_repeat": TS, + "num_replicate_kv_heads": TS, } -# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +# for h blocking, it internally sets head_block_size equal to num_devices/num_replicate_kv_heads model_name = "moonshotai/Kimi-K2-Thinking" model = AutoModelForCausalLM.from_pretrained( diff --git a/examples/text_generation/run_kimik2.py b/examples/text_generation/run_kimik2.py index 81767308ad..e85c572420 100644 --- a/examples/text_generation/run_kimik2.py +++ b/examples/text_generation/run_kimik2.py @@ -19,16 +19,16 @@ # qaic_config = None # Full PKV Cache # qaic_config = {"enable_blocking": True, "blocking_mode": "h"} # Full PKV Cache with Head Blocking # qaic_config = 
{"mla_absorption": mla_absorption} # for No Blocking -# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "num_replicate_kv_heads": TS} # No blocking with kv head replication # qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking -# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_kv_heads_repeat":TS} # for KV blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_replicate_kv_heads":TS} # for KV blocking with kv head replication qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", - "num_kv_heads_repeat": TS, + "num_replicate_kv_heads": TS, } -# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +# for h blocking, it internally sets head_block_size equal to num_devices/num_replicate_kv_heads model_name = "moonshotai/Kimi-K2-Thinking" model = AutoModelForCausalLM.from_pretrained( diff --git a/tests/configs/causal_model_configs.json b/tests/configs/causal_model_configs.json index 93f4e7ae2f..2c092ed9ee 100644 --- a/tests/configs/causal_model_configs.json +++ b/tests/configs/causal_model_configs.json @@ -325,6 +325,19 @@ "num_key_value_heads": 1 } }, + { + "model_name": "hpcai-tech/grok-1", + "model_type": null, + "additional_params": { + "max_position_embeddings": 128, + "num_hidden_layers": 1, + "num_attention_heads": 2, + "hidden_size": 64, + "intermediate_size": 256, + "vocab_size": 131072, + "num_key_value_heads": 1 + } + }, { "model_name": "Snowflake/Llama-3.1-SwiftKV-8B-Instruct", "model_type": null, @@ -720,4 +733,4 @@ } } ] -} +} \ No newline at end of file diff --git a/tests/transformers/models/causal_lm_models/check_causal_models.py 
b/tests/transformers/models/causal_lm_models/check_causal_models.py index f878acbe73..78ff74cbfd 100644 --- a/tests/transformers/models/causal_lm_models/check_causal_models.py +++ b/tests/transformers/models/causal_lm_models/check_causal_models.py @@ -16,7 +16,8 @@ from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers from QEfficient.utils._utils import load_hf_tokenizer -from QEfficient.utils.constants import Constants +from QEfficient.utils.config_utils import get_first_config_value +from QEfficient.utils.constants import ATTENTION_HEAD_CONFIG_KEYS, KV_HEAD_CONFIG_KEYS, Constants from QEfficient.utils.run_utils import ApiRunner from QEfficient.utils.test_utils import ModelConfig, load_hf_causal_lm_model @@ -39,6 +40,52 @@ def get_custom_n_layers(model_name): return 1 +def check_kv_repeat_causal_lm_pytorch_vs_ai100( + model_name: str, + manual_cleanup: callable, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = Constants.CTX_LEN, + n_layer: int = -1, + config: Optional[AutoConfig] = None, +): + """ + Validate causal LM flow with repeated KV heads configuration. 
+ """ + if config is None: + model_config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + ) + else: + model_config = config + + num_attention_heads = get_first_config_value(model_config, ATTENTION_HEAD_CONFIG_KEYS, default=1, cast_int=True) + num_key_value_heads = get_first_config_value(model_config, KV_HEAD_CONFIG_KEYS, default=None, cast_int=True) + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + if num_attention_heads < 1 or num_key_value_heads < 1: + raise ValueError( + f"Invalid heads in config for RepeatKV: " + f"num_attention_heads={num_attention_heads}, num_key_value_heads={num_key_value_heads}" + ) + if num_attention_heads % num_key_value_heads != 0: + raise ValueError( + f"Invalid heads in config for RepeatKV: num_attention_heads ({num_attention_heads}) " + f"is not divisible by num_key_value_heads ({num_key_value_heads})." + ) + num_replicate_kv_heads = num_attention_heads // num_key_value_heads + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + manual_cleanup=manual_cleanup, + prompt_len=prompt_len, + ctx_len=ctx_len, + n_layer=n_layer, + config=config, + qaic_config={"num_replicate_kv_heads": num_replicate_kv_heads}, + ) + + def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name: str, manual_cleanup: callable, @@ -71,15 +118,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_kv_tokens = None ort_tokens = None - api_runner = ApiRunner( - batch_size, - tokenizer, - config, - prompts, - Constants.PROMPT_LEN, - Constants.CTX_LEN, - full_batch_size if continuous_batching else None, - ) qeff_model = QEFFAutoModelForCausalLM( copy.deepcopy(model_hf), is_tlm=is_tlm, @@ -94,6 +132,15 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( num_devices=num_devices, qaic_config=qaic_config, ) + api_runner = ApiRunner( + batch_size, + tokenizer, + qeff_model.config, + prompts, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + 
full_batch_size if continuous_batching else None, + ) if continuous_batching is False: pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py index 5011a670a6..4bfc750adf 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py @@ -17,6 +17,7 @@ from .check_causal_models import ( check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100, + check_kv_repeat_causal_lm_pytorch_vs_ai100, get_custom_n_layers, ) @@ -73,6 +74,34 @@ def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanu check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, config=hf_config, manual_cleanup=manual_cleanup) +@pytest.mark.dummy_layers +@pytest.mark.on_qaic +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_models_causal) +def test_check_kv_repeat_custom_causal_lm_pytorch_vs_ai100(model_name, manual_cleanup): + """ + Test function to validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads. 
+ ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + if model_name in ModelConfig.SKIPPED_MODELS: + pytest.skip("Test skipped for this model due to issues in HF.") + custom_config = model_config_dict[model_name] + hf_config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + **custom_config.get("additional_params", {}), + ) + if model_name in ModelConfig.REPEAT_KV_TEST_MODELS: + if model_name in ModelConfig.QUANTIZED_MODELS: + n_layer = get_custom_n_layers(model_name) + check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name, manual_cleanup=manual_cleanup, n_layer=n_layer) + else: + check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name, manual_cleanup=manual_cleanup, config=hf_config) + else: + pytest.skip(f"Skipping {model_name} as it is not in REPEAT_KV_TEST_MODELS") + + @pytest.mark.full_layers @pytest.mark.on_qaic @pytest.mark.llm_model diff --git a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py index 9b9e662e52..df9c3b9e8d 100644 --- a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py +++ b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py @@ -30,6 +30,7 @@ from QEfficient.utils.test_utils import ( InternProcessor, ModelConfig, + get_text_config, load_vlm_model, load_vlm_model_from_config, set_num_layers_vlm, @@ -56,6 +57,9 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, config: Optional[AutoConfig] = None, + qaic_config: Optional[dict] = None, + num_replicate_kv_heads: Optional[int] = 1, + test_kv_replicate: Optional[bool] = None, torch_dtype: Optional[torch.dtype] = torch.float32, compare_results: Optional[bool] = False, ): @@ -70,11 +74,17 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( 
pytorch_kv_tokens = None ort_tokens = None n_layer = num_hidden_layers + qaic_config = copy.deepcopy(qaic_config) if qaic_config is not None else None if config is None: config = AutoConfig.from_pretrained( model_name, trust_remote_code=True, padding=model_name not in ModelConfig.MOLMO_MODELS ) config = set_num_layers_vlm(config, n_layer=n_layer) + if test_kv_replicate: + text_config = get_text_config(config) + num_replicate_kv_heads = text_config.num_attention_heads // text_config.num_key_value_heads + qaic_config = qaic_config or {} + qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads if hasattr(config, "model_type") and config.model_type in ["gemma3"]: config.text_config._sliding_window_pattern = 2 config.text_config.layer_types = ["sliding_attention", "full_attention"] @@ -92,7 +102,9 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name, kv_offload=kv_offload, config=config, + qaic_config=qaic_config, torch_dtype=torch_dtype, + num_replicate_kv_heads=num_replicate_kv_heads, ) else: model_hf = load_vlm_model(config) @@ -100,15 +112,24 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name, kv_offload=kv_offload, config=config, + qaic_config=qaic_config, torch_dtype=torch_dtype, + num_replicate_kv_heads=num_replicate_kv_heads, ) else: + if test_kv_replicate: + text_config = get_text_config(config) + num_replicate_kv_heads = text_config.num_attention_heads // text_config.num_key_value_heads + qaic_config = qaic_config or {} + qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads model_hf = load_vlm_model_from_config(config) qeff_model = QEFFAutoModelForImageTextToText( copy.deepcopy(model_hf), kv_offload=kv_offload, config=model_hf.config, + qaic_config=qaic_config, torch_dtype=torch_dtype, + num_replicate_kv_heads=num_replicate_kv_heads, ) compile_kwargs = { "num_devices": num_devices, @@ -117,6 +138,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( "mxfp6": False, "enable_qnn": 
enable_qnn, "qnn_config": qnn_config, + "qaic_config": qaic_config, } if model_name in ModelConfig.INTERNVL_MODELS: @@ -239,7 +261,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( # "Tokens don't match for pytorch HF output and pytorch KV output" # ) - _ = qeff_model.export() + # _ = qeff_model.export() # ort_tokens = api_runner.run_vlm_kv_model_on_ort(onnx_model_path) # assert (pytorch_hf_tokens == ort_tokens).all(), "Tokens don't match for pytorch HF output and ORT output" @@ -337,6 +359,57 @@ def test_dummy_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(model_name, kv_o ) +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.dummy_layers +@pytest.mark.parametrize("model_name", test_mm_models) +@pytest.mark.parametrize("kv_offload", [True, False]) +def test_custom_replicate_kv_pytorch_vs_ai100( + model_name, + kv_offload, + manual_cleanup, +): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching. 
+ ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + torch.manual_seed(42) + if model_name in ModelConfig.SKIPPED_MODELS: + pytest.skip("Test skipped for this model due to some issues.") + if model_name in ModelConfig.DUAL_QPC_MODELS and not kv_offload: + pytest.skip("These models require kv_offload=True for testing.") + + if model_name in ModelConfig.REPEAT_KV_TEST_MODELS: + hf_config = None + if model_name in ModelConfig.STANDARD_VLM_MODELS: + model_type = model_config_dict[model_name].get("model_type") + custom_config = model_config_dict[model_name].get("additional_params", {}) + hf_config = AutoConfig.for_model(model_type, trust_remote_code=True, **custom_config) + hf_config.name_or_path = model_name + + if hf_config is not None: + check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + kv_offload=kv_offload, + config=hf_config, + qaic_config={}, + test_kv_replicate=True, + manual_cleanup=manual_cleanup, + ) + else: + check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + num_hidden_layers=model_config_dict[model_name]["num_layers"], + kv_offload=kv_offload, + qaic_config={}, + test_kv_replicate=True, + manual_cleanup=manual_cleanup, + ) + else: + pytest.skip(f"Skipping replicate KV test for {model_name} as it's not in REPEAT_KV_TEST_MODELS") + + ################################ QNN Tests ################################ From ab8813b0346e1f1b98843a082b798898c180d8ac Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Thu, 11 Jun 2026 16:01:20 +0530 Subject: [PATCH 2/2] Minor fix for rebased code. 
Signed-off-by: Dhiraj Kumar Sah --- QEfficient/transformers/models/pytorch_transforms.py | 10 ---------- .../transformers/models/qwen3_5/modeling_qwen3_5.py | 2 +- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 20779c9dac..fec8582f90 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -620,16 +620,6 @@ QEffT5LayerNorm, QEffT5Stack, ) - -# from QEfficient.transformers.models.repeat_kv_utils import ( -# duplicate_kv_projection_weights, -# get_attention_module, -# get_projection_layer, -# get_text_model, -# is_mla_model, -# is_replication_applied, -# replication_targets, -# ) from QEfficient.transformers.models.wav2vec2.modeling_wav2vec2 import ( QEffWav2Vec2Encoder, QEffWav2Vec2EncoderStableLayerNorm, diff --git a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py index eb58eba8d9..54b146b36b 100644 --- a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1636,7 +1636,7 @@ def get_specializations( for h, w, f in zip(height, width, num_frames): resized_height, resized_width = smart_resize( - height=h, width=w, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels + height=h, width=w, factor=image_factor, min_pixels=min_pixels, max_pixels=max_pixels ) grid_h, grid_w = resized_height // patch_size, resized_width // patch_size grid_height = grid_h * grid_w