Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
require_value,
to_named_specializations,
)
from QEfficient.utils.config_utils import calculate_num_replicate_kv_heads
from QEfficient.utils.export_utils import export_wrapper

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -735,23 +736,48 @@ def transform(
**compiler_options,
):
# Apply the transformations that are dependent on compilation parameters
def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module:
"""
Use the shared wrapped model as transform-tracking root when available.
This lets encoder/decoder wrappers coordinate one-time transforms.
"""
wrapped = getattr(module, "model", None)
return wrapped if isinstance(wrapped, torch.nn.Module) else module

qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None)

model_config = getattr(self.model, "config", None) or getattr(self.model.model, "config", None)
model_config = getattr(self.model, "config", None) or getattr(
getattr(self.model, "model", None), "config", None
)
num_replicate_kv_heads = 1
if model_config is not None:
num_replicate_kv_heads = calculate_num_replicate_kv_heads(
num_devices=num_devices,
text_model_config=model_config,
)

if model_config:
if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []):
if qaic_config:
if qaic_config.get("blocking_mode", None) == "h":
qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices)
num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1)
if qaic_config is not None:
num_replicate_kv_heads = qaic_config.get("num_replicate_kv_heads", num_replicate_kv_heads)
qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads
transform_root = _transform_tracking_root(self.model)
applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set())
should_apply_repeat_kv = num_replicate_kv_heads is not None and num_replicate_kv_heads > 1
if not should_apply_repeat_kv:
replicate_kv_transformed = False
elif ReplicateKVHeadTransform.__name__ in applied_transforms:
replicate_kv_transformed = False
logger.warning("Skipping RepeatKVTransform: already applied on this model instance.")
else:
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(
self.model, num_kv_heads_repeat
self.model,
num_replicate_kv_heads,
)
if replicate_kv_transformed:
self.hash_params["config"] = self.model.config.to_diff_dict()

applied_transforms.add(ReplicateKVHeadTransform.__name__)
setattr(transform_root, "_qeff_runtime_transforms_applied", applied_transforms)
if replicate_kv_transformed:
self.hash_params["config"] = self.model.config.to_diff_dict()
blocking_config = build_transformer_blocking_config_for_transform(
model_config,
ctx_len=ctx_len,
Expand All @@ -768,6 +794,7 @@ def transform(
if blocking_config is not None:
self.model, _ = BlockingAttentionTransform.apply(self.model, attn_blocking_config=blocking_config)
self.hash_params["blocking_kwargs"] = blocking_config
self.hash_params["num_replicate_kv_heads"] = num_replicate_kv_heads

@dump_qconfig
def _compile(
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ class QEffGemma3EncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model.model
self.config = self.model.config
self.model.vision_model = self.model.vision_tower

def get_submodules_for_export(self) -> Type[nn.Module]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class QEffInternEncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/llama4/modeling_llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ class QEffLlama4EncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, model):
super().__init__()
self.model = model
self.model.vision_model = self.model.model.vision_tower
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, model):
super().__init__()
self.model = model
self.model.vision_model = self.model.model.vision_tower
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ class QEFFMistral3EncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config
self.model.model.vision_model = self.model.model.vision_tower

def get_submodules_for_export(self) -> Type[nn.Module]:
Expand Down
16 changes: 15 additions & 1 deletion QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
)

_resolve_torch_dtype(kwargs)
num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {})
Expand All @@ -1469,6 +1470,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
model,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
num_replicate_kv_heads=num_replicate_kv_heads,
**kwargs,
)

Expand Down Expand Up @@ -1593,7 +1595,12 @@ def export(
if prefill_only and prefill_seq_len > 1:
offload_pt_weights = False # to keep weight for decode onnx
else:
offload_pt_weights = kwargs.get("offload_pt_weights", True)
num_replicate_kv_heads = (
(self.lang_model.model.qaic_config or {}).get("num_replicate_kv_heads", 1)
if hasattr(self.lang_model.model, "qaic_config")
else 1
)
offload_pt_weights = kwargs.get("offload_pt_weights", num_replicate_kv_heads <= 1)

if not skip_lang:
self.lang_model.export(
Expand Down Expand Up @@ -2616,6 +2623,7 @@ def from_pretrained(
config._attn_implementation = "eager"
config.vision_config.use_flash_attn = "false"
_resolve_torch_dtype(kwargs)
num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs)

kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {})
Expand All @@ -2624,6 +2632,7 @@ def from_pretrained(
model,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
num_replicate_kv_heads=num_replicate_kv_heads,
**kwargs,
)

Expand Down Expand Up @@ -3275,6 +3284,7 @@ def from_pretrained(
)

_resolve_torch_dtype(kwargs)
num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1)
if layerwise:
# Layer-wise mode: build the outer model on the meta device so the
# caller's ``from_pretrained`` does not pull the full checkpoint
Expand All @@ -3293,6 +3303,7 @@ def from_pretrained(
continuous_batching=continuous_batching,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
num_replicate_kv_heads=num_replicate_kv_heads,
**kwargs,
)
# Mark the wrapper so its compile() can default ``layerwise=True`` if
Expand Down Expand Up @@ -3550,6 +3561,7 @@ def from_pretrained(
)

_resolve_torch_dtype(kwargs)
num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1)
if layerwise:
# Layer-wise mode: build the outer model on the meta device. The
# caller still gets a typed wrapper, but no checkpoint weights are
Expand All @@ -3570,6 +3582,7 @@ def from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
continuous_batching=continuous_batching,
num_replicate_kv_heads=num_replicate_kv_heads,
**kwargs,
)
instance = cls(
Expand All @@ -3578,6 +3591,7 @@ def from_pretrained(
qaic_config=qaic_config,
pretrained_model_name_or_path=pretrained_model_name_or_path,
max_seq_len_cached=max_seq_len_cached,
num_replicate_kv_heads=num_replicate_kv_heads,
**kwargs,
)
if layerwise:
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/molmo/modeling_molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ class QEffMolmoEncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
Loading
Loading