diff --git a/.github/workflows/prek.yml b/.github/workflows/prek.yml index b2550f761..15ab4264f 100644 --- a/.github/workflows/prek.yml +++ b/.github/workflows/prek.yml @@ -202,7 +202,7 @@ jobs: done < "${part_paths_file}" | zstd -d -c | tar -xf - -C "${UV_CACHE_DIR}" du -sh "${UV_CACHE_DIR}" - - name: Install dependencies (with all optional extras for complete type checking) + - name: Install Megatron dependencies run: | original_pyproject="$(mktemp)" cp pyproject.toml "${original_pyproject}" @@ -229,12 +229,31 @@ jobs: --apex-nvcc-threads "${CI_APEX_NVCC_THREADS}" echo "CI uv build overrides: APEX_PARALLEL_BUILD=${CI_APEX_PARALLEL_BUILD}, NVCC_APPEND_FLAGS=--threads ${CI_APEX_NVCC_THREADS}, UV_CONCURRENT_BUILDS=${CI_UV_BUILD_SLOTS}" uv --version - uv sync --all-extras --group dev --frozen --python "${CI_PYTHON_MM}" + uv sync --extra megatron --extra langgraph --extra plotting --group dev --frozen --python "${CI_PYTHON_MM}" - - name: Run prek hooks (lint, format, typecheck, uv.lock, tests) + - name: Run prek hooks (lint, format, typecheck, uv.lock) run: | - uv run --no-sync prek run --all-files + uv run --no-sync prek run ruff --all-files + uv run --no-sync prek run ruff-format --all-files + uv run --no-sync prek run ty --all-files + uv run --no-sync prek run uv-lock-check --all-files - - name: Run unit tests (via prek) + - name: Run Megatron unit tests run: | - uv run --no-sync prek run pytest + uv run --no-sync pytest --nbval --current-env --tb=short \ + tests/unit/test_megatron_reference_logprobs.py \ + tests/unit/test_moe_routing_replay.py \ + tests/unit/test_moe_routing_real_path.py \ + tests/unit/test_pipeline_trainer_local_backend.py + + - name: Install backend dependencies + run: | + uv sync --extra backend --extra tinker --extra langgraph --extra plotting --group dev --frozen --python "${CI_PYTHON_MM}" + + - name: Run unit tests + run: | + uv run --no-sync pytest --nbval --current-env --tb=short tests/unit \ + 
--ignore=tests/unit/test_megatron_reference_logprobs.py \ + --ignore=tests/unit/test_moe_routing_replay.py \ + --ignore=tests/unit/test_moe_routing_real_path.py \ + --ignore=tests/unit/test_pipeline_trainer_local_backend.py diff --git a/pyproject.toml b/pyproject.toml index bfa06e5d1..a9d1197df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ backend = [ "bitsandbytes>=0.45.2", "unsloth==2026.3.3", "unsloth-zoo==2026.3.1", - "torch>=2.11.0", + "torch==2.11.0", "torchao==0.16.0", "accelerate==1.7.0", "awscli>=1.38.1", @@ -43,7 +43,7 @@ backend = [ ] megatron = [ "numpy<2", - "torch>=2.11.0", + "torch==2.11.0", "flash-attn-4==4.0.0b5", "ninja>=1.11.1", "quack-kernels==0.3.7", @@ -61,6 +61,7 @@ megatron = [ "nvidia-ml-py==13.580.82", "nvidia-modelopt>=0.42.0a0 ; sys_platform != 'darwin'", "nvidia-resiliency-ext<0.5 ; sys_platform == 'linux'", + "transformers==5.6.2", "ml-dtypes>=0.5.0 ; python_full_version < '3.13'", ] @@ -79,8 +80,8 @@ tinker = [ "protobuf>=6.31.1", "tinker-cookbook>=0.4.1,<0.5", "tinker>=0.21.0,<0.22", - "torch>=2.11.0", - "transformers==5.2.0", + "torch==2.11.0", + "transformers>=5.2.0,<=5.5.3", "uvicorn>=0.35.0", "datrie>=0.8.3", ] @@ -150,14 +151,23 @@ markers = [ [tool.uv] required-version = ">=0.11.7" +conflicts = [ + [ + { extra = "backend" }, + { extra = "megatron" }, + ], + [ + { extra = "tinker" }, + { extra = "megatron" }, + ], +] override-dependencies = [ - "flashinfer-python==0.6.1", + "flashinfer-python==0.6.8.post1", "megatron-core==0.17.0", "numpy<2", "nvidia-resiliency-ext<0.5", "quack-kernels==0.3.7", "transformer-engine==2.11.0", - "transformers==5.2.0", "torch==2.11.0", ] exclude-dependencies = ["pynvml", "emerging-optimizers", "causal-conv1d", "mamba-ssm"] @@ -184,6 +194,46 @@ name = "deep-ep" version = "1.2.1+9af0e0d" requires-dist = [] +# The Megatron Bridge source metadata currently requires Transformers 5.8.x, +# but this branch is validated against Transformers 5.6.2 for Gemma 4. 
+# Keep Bridge's runtime deps explicit here and let ART's megatron extra own the +# Transformers pin. +[[tool.uv.dependency-metadata]] +name = "megatron-bridge" +version = "0.5.0+e1a207ac" +requires-dist = [ + "accelerate", + "comet-ml", + "datasets", + "diffusers", + "einops", + "flash-linear-attention", + "flashinfer-cubin", + "flashinfer-python", + "hydra-core", + "imageio", + "imageio-ffmpeg", + "megatron-core", + "mistral-common", + "mlflow", + "nvidia-resiliency-ext", + "omegaconf", + "open-clip-torch", + "peft", + "pyyaml", + "qwen-vl-utils", + "regex", + "rich", + "six", + "tensorboard", + "timm", + "torch", + "tqdm", + "transformers", + "typing-extensions", + "wandb", +] + [[tool.uv.dependency-metadata]] name = "transformer-engine-torch" version = "2.11.0" @@ -276,7 +326,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_pla apex = { git = "https://github.com/NVIDIA/apex.git", rev = "25.09" } deep-ep = { git = "https://github.com/deepseek-ai/DeepEP.git", rev = "v1.2.1" } flash-attn-4 = { url = "https://files.pythonhosted.org/packages/24/f7/01ee2576ce41f9884d291ee21861ef194afc0b2b1ce3bd175fc7a6e1b133/flash_attn_4-4.0.0b5-py3-none-any.whl" } -megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "e049cc00c24d03e2ae45d2608c7a44e2d2364e3d" } +megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "e1a207ac757e5d0ed94d8ffbe1cbd28e81d8c084" } panza = { git = "https://github.com/corbt/panza.git" } transformer-engine-torch = { git = "https://github.com/NVIDIA/TransformerEngine.git", rev = "v2.11", subdirectory = "transformer_engine/pytorch" } diff --git a/scripts/ci/build_and_push_uv_cache.sh b/scripts/ci/build_and_push_uv_cache.sh index f98db5f3e..5e7535a66 100755 --- a/scripts/ci/build_and_push_uv_cache.sh +++ b/scripts/ci/build_and_push_uv_cache.sh @@ -283,8 +283,9 @@ build_cache_archive() { export LIBRARY_PATH="${CUDNN_LIBRARY_PATH}${LIBRARY_PATH:+:${LIBRARY_PATH}}" 
export LD_LIBRARY_PATH="${CUDNN_LIBRARY_PATH}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" - log "Building full uv cache with compile_jobs=${compile_jobs}, apex_parallel_build=${apex_parallel_build}, nvcc_threads=${CI_APEX_NVCC_THREADS}, cuda_arch_list=${TORCH_CUDA_ARCH_LIST}, and uv_concurrent_builds=${UV_BUILD_SLOTS}." - uv sync --frozen --all-extras --group dev --no-install-project --python "${PYTHON_MM}" + log "Building split uv cache with compile_jobs=${compile_jobs}, apex_parallel_build=${apex_parallel_build}, nvcc_threads=${CI_APEX_NVCC_THREADS}, cuda_arch_list=${TORCH_CUDA_ARCH_LIST}, and uv_concurrent_builds=${UV_BUILD_SLOTS}." + uv sync --frozen --extra megatron --extra langgraph --extra plotting --group dev --no-install-project --python "${PYTHON_MM}" + uv sync --frozen --extra backend --extra tinker --extra langgraph --extra plotting --group dev --no-install-project --python "${PYTHON_MM}" rm -rf .venv log "Packing uv cache archive to ${archive_path}." diff --git a/scripts/ci/compute_uv_fingerprint.py b/scripts/ci/compute_uv_fingerprint.py index f9029edf5..a200251c0 100755 --- a/scripts/ci/compute_uv_fingerprint.py +++ b/scripts/ci/compute_uv_fingerprint.py @@ -83,9 +83,9 @@ def main() -> int: "uv_lock_sha256": _sha256_file(args.uv_lock), }, "ci_context": { - "fingerprint_schema_version": 9, + "fingerprint_schema_version": 10, "cache_kind": "full_uv_cache", - "cache_scope": "prek_all_extras_group_dev", + "cache_scope": "prek_split_extras_group_dev", "cache_target": "uv_cache", "cache_python_platform": "linux_x86_64", "cache_package_manager": "uv", diff --git a/scripts/setup.sh b/scripts/setup.sh index 76936fcd9..cc34695f3 100755 --- a/scripts/setup.sh +++ b/scripts/setup.sh @@ -72,7 +72,7 @@ fi # Sync the dependencies if [ "${INSTALL_EXTRAS:-false}" = "true" ]; then - uv sync --all-extras --frozen + uv sync --extra backend --extra tinker --extra langgraph --extra plotting --frozen else uv sync --extra backend --frozen fi diff --git a/src/art/__init__.py 
b/src/art/__init__.py index 675500c23..ddb0d8c1f 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -64,11 +64,16 @@ from .batches import trajectory_group_batches from .dev import LoRAConfig from .gather import gather_trajectories, gather_trajectory_groups +from .megatron.runtime_config import ( + get_megatron_runtime_config, + init_megatron_runtime_config, +) from .model import Model, TrainableModel from .serverless import ServerlessBackend from .trajectories import Trajectory, TrajectoryGroup from .types import ( LocalTrainResult, + MegatronRuntimeConfig, MegatronTopologyConfig, Messages, MessagesAndChoices, @@ -91,7 +96,10 @@ "Backend", "LocalTrainResult", "LoRAConfig", + "MegatronRuntimeConfig", "MegatronTopologyConfig", + "get_megatron_runtime_config", + "init_megatron_runtime_config", "ServerlessBackend", "ServerlessTrainResult", "Messages", diff --git a/src/art/_backend_training.py b/src/art/_backend_training.py index a2feb8133..97ce4b079 100644 --- a/src/art/_backend_training.py +++ b/src/art/_backend_training.py @@ -9,7 +9,7 @@ summarize_trajectory_groups, ) from .trajectories import TrajectoryGroup -from .types import MegatronTopologyConfig, TrainConfig +from .types import TrainConfig def build_rl_train_configs( @@ -35,7 +35,6 @@ def build_rl_train_configs( scale_learning_rate_by_reward_std_dev: bool | None = None, logprob_calculation_chunk_size: int | None = None, packed_sequence_length: int | None = None, - megatron_topology: MegatronTopologyConfig | dict[str, int | None] | None = None, num_trajectories_learning_rate_multiplier_power: float | None = None, kl_ref_adapter_path: str | None = None, ) -> tuple[TrainConfig, dev.TrainConfig]: @@ -69,10 +68,6 @@ def build_rl_train_configs( dev_config["logprob_calculation_chunk_size"] = logprob_calculation_chunk_size if packed_sequence_length is not None: dev_config["packed_sequence_length"] = packed_sequence_length - if megatron_topology is not None: - dev_config["megatron_topology"] = 
MegatronTopologyConfig.model_validate( - megatron_topology - ).model_dump(mode="json") if num_trajectories_learning_rate_multiplier_power is not None: dev_config["num_trajectories_learning_rate_multiplier_power"] = ( num_trajectories_learning_rate_multiplier_power diff --git a/src/art/auto_trajectory.py b/src/art/auto_trajectory.py index 0b0860808..434bcc403 100644 --- a/src/art/auto_trajectory.py +++ b/src/art/auto_trajectory.py @@ -9,6 +9,7 @@ from .openai import init_chat_completion, update_chat_completion from .preprocessing.moe_routing import attach_moe_routing_metadata_to_choice +from .preprocessing.vllm_tokens import attach_vllm_token_metadata_to_choice from .trajectories import History, Trajectory logger = logging.getLogger(__name__) @@ -105,9 +106,25 @@ def handle_httpx_response(self, response: httpx._models.Response) -> None: # Parse SSE content directly from buffered bytes chat_completion = parse_sse_to_chat_completion(content) choice = chat_completion.choices[0] + response_payload = chat_completion.model_dump(mode="python") + attach_vllm_token_metadata_to_choice( + choice=choice, + response_payload=response_payload, + choice_index=0, + ) + attach_moe_routing_metadata_to_choice( + choice=choice, + response_payload=response_payload, + choice_index=0, + ) else: response_payload = json.loads(content) choice = Choice(**response_payload["choices"][0]) + attach_vllm_token_metadata_to_choice( + choice=choice, + response_payload=response_payload, + choice_index=0, + ) attach_moe_routing_metadata_to_choice( choice=choice, response_payload=response_payload, diff --git a/src/art/dev/get_model_config.py b/src/art/dev/get_model_config.py index c48f2cbd3..57eda12de 100644 --- a/src/art/dev/get_model_config.py +++ b/src/art/dev/get_model_config.py @@ -8,6 +8,7 @@ LoRAConfig, TrainerArgs, ) +from .sequence_lengths import max_seq_length_from_model_config from .validate import is_dedicated_mode @@ -36,9 +37,14 @@ def get_model_config( else: enable_sleep_mode = 
config.get("engine_args", {}).get("enable_sleep_mode", True) + configured_init_args = config.get("init_args", {}) init_args = InitArgs( load_in_4bit=True, - max_seq_length=32768, + max_seq_length=max_seq_length_from_model_config( + base_model, + revision=configured_init_args.get("revision"), + token=configured_init_args.get("token"), + ), model_name=base_model, ) engine_args = EngineArgs( @@ -48,7 +54,7 @@ def get_model_config( model=base_model, ) engine_args.update(config.get("engine_args", {})) - init_args.update(config.get("init_args", {})) + init_args.update(configured_init_args) if last_checkpoint_dir := get_last_checkpoint_dir(output_dir): init_args["model_name"] = last_checkpoint_dir merged_lora_config = LoRAConfig( @@ -95,6 +101,4 @@ def get_model_config( result["trainer_gpu_ids"] = config["trainer_gpu_ids"] if "inference_gpu_ids" in config: result["inference_gpu_ids"] = config["inference_gpu_ids"] - if "megatron_topology" in config: - result["megatron_topology"] = config["megatron_topology"] return result diff --git a/src/art/dev/model.py b/src/art/dev/model.py index dc5624dbd..a042c2d47 100644 --- a/src/art/dev/model.py +++ b/src/art/dev/model.py @@ -1,13 +1,10 @@ from enum import Enum -from typing import TYPE_CHECKING, Literal, NoReturn +from typing import Literal, NoReturn from typing_extensions import Required, TypedDict from .engine import EngineArgs -if TYPE_CHECKING: - from ..types import MegatronTopologyConfig - RolloutWeightsMode = Literal["lora", "merged"] @@ -138,7 +135,6 @@ class InternalModelConfig(TypedDict, total=False): chat_template_content_format: vLLM chat template content format. chat_template_tool_schema_format: Tool schema rendering format used for local training tokenization. - megatron_topology: Fixed Megatron parallel topology for this model. allow_unvalidated_arch: Permit model-support validation workflows to run architectures that are not yet in the supported-model registry. 
""" @@ -156,7 +152,6 @@ class InternalModelConfig(TypedDict, total=False): chat_template_path: str chat_template_content_format: str chat_template_tool_schema_format: Literal["default", "vllm_openai"] - megatron_topology: "MegatronTopologyConfig | dict[str, int | None]" allow_unvalidated_arch: bool diff --git a/src/art/dev/sequence_lengths.py b/src/art/dev/sequence_lengths.py new file mode 100644 index 000000000..b0975ca6e --- /dev/null +++ b/src/art/dev/sequence_lengths.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from collections.abc import Iterator, Mapping +from typing import Any, TypeGuard + +_MAX_SEQ_LENGTH_KEYS = ( + "max_position_embeddings", + "n_positions", + "seq_length", + "max_sequence_length", + "model_max_length", +) +_TEXT_CONFIG_KEYS = ("text_config", "llm_config", "language_config") + + +def _config_sections(config_dict: Mapping[str, Any]) -> Iterator[Mapping[str, Any]]: + for key in _TEXT_CONFIG_KEYS: + section = config_dict.get(key) + if isinstance(section, Mapping): + yield section + yield config_dict + + +def _valid_max_seq_length(value: object) -> TypeGuard[int]: + return isinstance(value, int) and 0 < value < 1_000_000_000 + + +def max_seq_length_from_model_config( + base_model: str, + *, + revision: str | None = None, + token: str | None = None, +) -> int: + from transformers import PretrainedConfig + + kwargs = { + key: value + for key, value in {"revision": revision, "token": token}.items() + if value is not None + } + config_dict, _ = PretrainedConfig.get_config_dict(base_model, **kwargs) + for section in _config_sections(config_dict): + for key in _MAX_SEQ_LENGTH_KEYS: + value = section.get(key) + if _valid_max_seq_length(value): + return int(value) + raise ValueError( + f"Could not infer max_seq_length from Hugging Face config for {base_model!r}. " + "Set init_args.max_seq_length explicitly." 
+ ) diff --git a/src/art/dev/train.py b/src/art/dev/train.py index aea05cae4..495125baa 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -30,10 +30,6 @@ class TrainConfig(TypedDict, total=False): logprob_calculation_chunk_size: int mask_prob_ratio: bool max_negative_advantage_importance_sampling_weight: float - megatron_topology: dict[ - Literal["tp", "cp", "ep", "pp", "vpp", "etp"], - int | None, - ] moe_routing_replay_bundle: "MoeRoutingReplayBundle | None" moe_routing_replay_path: str | None moe_routing_replay_strict: bool diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 431de179b..f277a3e4c 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -52,6 +52,7 @@ ) from ..backend import AnyTrainableModel, Backend from ..costs import build_cost_calculator, get_model_pricing +from ..dev.sequence_lengths import max_seq_length_from_model_config from ..metrics_taxonomy import ( TRAIN_GRADIENT_STEPS_KEY, build_training_summary_metrics, @@ -72,7 +73,6 @@ from ..trajectories import Trajectory, TrajectoryGroup from ..types import ( LocalTrainResult, - MegatronTopologyConfig, Message, TrainConfig, TrainSFTConfig, @@ -188,6 +188,9 @@ def __init__( self._services: dict[str, ModelService] = {} self._adapter_leases: dict[str, AdapterLeaseManager] = {} self._tokenizers: dict[tuple[str, str | None], PreTrainedTokenizerBase] = {} + self._model_max_sequence_lengths: dict[ + tuple[str, str | None, str | None], int + ] = {} self._image_processors: dict[str, BaseImageProcessor | None] = {} self._requires_explicit_packed_sequence_length = False self._packed_sequence_length_requires_chunk_alignment = True @@ -217,6 +220,27 @@ def _model_uses_expert_replay(self, model: AnyTrainableModel) -> bool: except UnsupportedModelArchitectureError: return False + def _model_max_sequence_length(self, model: AnyTrainableModel) -> int: + internal_config = cast(dev.InternalModelConfig, model._internal_config or {}) + configured = 
internal_config.get("init_args", {}).get("max_seq_length") + if configured is not None: + return int(configured) + init_args = internal_config.get("init_args", {}) + cache_key = ( + model.base_model, + init_args.get("revision"), + init_args.get("token"), + ) + if cache_key not in self._model_max_sequence_lengths: + self._model_max_sequence_lengths[cache_key] = ( + max_seq_length_from_model_config( + model.base_model, + revision=cache_key[1], + token=cache_key[2], + ) + ) + return self._model_max_sequence_lengths[cache_key] + def supports_automatic_train_step_metrics(self) -> bool: return True @@ -536,9 +560,28 @@ def _get_packed_tensors( ) if not tokenized_results: return None - model_max_sequence_length = internal_config.get("init_args", {}).get( - "max_seq_length", 32_768 - ) + model_max_sequence_length = self._model_max_sequence_length(model) + too_long_for_model = [ + result + for result in tokenized_results + if len(result.token_ids) > model_max_sequence_length + ] + if too_long_for_model: + warnings.warn( + "Dropping " + f"{len(too_long_for_model)} tokenized results from " + f"{len({id(result.trajectory) for result in too_long_for_model})} " + f"trajectories longer than model max_seq_length={model_max_sequence_length} " + f"(max seen {max(len(result.token_ids) for result in too_long_for_model)}).", + stacklevel=2, + ) + tokenized_results = [ + result + for result in tokenized_results + if len(result.token_ids) <= model_max_sequence_length + ] + if not tokenized_results: + return None if packed_sequence_length is None: assert not self._requires_explicit_packed_sequence_length, ( f"{type(self).__name__} requires packed_sequence_length to be set." 
@@ -551,11 +594,6 @@ def _get_packed_tensors( else: sequence_length = packed_sequence_length - if sequence_length > model_max_sequence_length: - raise ValueError( - f"packed_sequence_length ({sequence_length}) exceeds model max_seq_length " - f"({model_max_sequence_length})" - ) if ( packed_sequence_length is not None and self._packed_sequence_length_requires_chunk_alignment @@ -576,7 +614,7 @@ def _get_packed_tensors( "Dropping " f"{len(too_long_results)} tokenized results from " f"{len({id(result.trajectory) for result in too_long_results})} " - f"trajectories longer than packed_sequence_length={sequence_length} " + f"trajectories that do not fit packed_sequence_length={sequence_length} " f"(max seen {max(len(result.token_ids) for result in too_long_results)}). " "This affects training, but your model may still learn.", stacklevel=2, @@ -656,7 +694,6 @@ async def _prepare_backend_for_training( if self._model_uses_expert_replay(model): engine_args = dict(config_dict.get("engine_args", {})) engine_args["enable_return_routed_experts"] = True - engine_args["async_scheduling"] = False config_dict["engine_args"] = engine_args server_args = dict(config_dict.get("server_args", {})) @@ -727,7 +764,6 @@ async def train( # type: ignore[override] scale_learning_rate_by_reward_std_dev: bool = False, logprob_calculation_chunk_size: int = 1024, packed_sequence_length: int | None = None, - megatron_topology: MegatronTopologyConfig | None = None, num_trajectories_learning_rate_multiplier_power: float = 0.0, # Checkpoint behavior save_checkpoint: bool = True, @@ -793,9 +829,6 @@ async def train( # type: ignore[override] packed_sequence_length: Packed sequence length to use for training. When unset, Unsloth keeps the current max-length-rounded-to-2048 behavior. Required for Megatron. - megatron_topology: Parallel topology for Megatron training. When - provided, ART uses it to configure Megatron TP/CP/EP/PP/VPP/ETP - before launching the Megatron runtime. 
num_trajectories_learning_rate_multiplier_power: Power for learning rate multiplier based on number of trajectories. save_checkpoint: Whether to save a checkpoint after training. @@ -867,7 +900,6 @@ async def train( # type: ignore[override] scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev, logprob_calculation_chunk_size=logprob_calculation_chunk_size, packed_sequence_length=packed_sequence_length, - megatron_topology=megatron_topology, num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power, kl_ref_adapter_path=resolved_kl_ref_adapter_path, ) @@ -1132,10 +1164,7 @@ async def _train_sft( print(f"Using instruction_part: {instruction_part!r}") print(f"Using response_part: {response_part!r}") - max_seq_length = internal_config.get("init_args", {}).get( - "max_seq_length", 32_768 - ) - max_seq_length = int(max_seq_length) if max_seq_length is not None else None + max_seq_length = self._model_max_sequence_length(model) import itertools from typing import Iterator diff --git a/src/art/megatron/backend.py b/src/art/megatron/backend.py index 14e5d2e31..9052303eb 100644 --- a/src/art/megatron/backend.py +++ b/src/art/megatron/backend.py @@ -1,9 +1,15 @@ +from typing import Any, Iterable + from mp_actors import move_to_child_process +from ..backend import AnyTrainableModel from ..local.backend import LocalBackend from ..local.service import ModelService from ..model import TrainableModel +from ..trajectories import TrajectoryGroup +from ..types import LocalTrainResult from ..utils.output_dirs import get_model_dir +from .runtime_config import get_megatron_runtime_config class MegatronBackend(LocalBackend): @@ -23,6 +29,25 @@ def __init__( self._packed_sequence_length_requires_chunk_alignment = False self._supports_result_packing = True + async def train( + self, + model: AnyTrainableModel, + trajectory_groups: Iterable[TrajectoryGroup], + **kwargs: Any, + ) -> LocalTrainResult: + for removed_kwarg in 
("packed_sequence_length", "megatron_topology"): + if removed_kwarg in kwargs: + raise TypeError( + f"MegatronBackend.train gets {removed_kwarg} from " + "art.init_megatron_runtime_config(...)." + ) + return await super().train( + model, + trajectory_groups, + packed_sequence_length=get_megatron_runtime_config().packed_sequence_length, + **kwargs, + ) + async def _get_service(self, model: TrainableModel) -> ModelService: from ..dev.get_model_config import get_model_config from .service import MegatronService diff --git a/src/art/megatron/context_parallel/block_mask.py b/src/art/megatron/context_parallel/block_mask.py index 91fe2023b..f86f63320 100644 --- a/src/art/megatron/context_parallel/block_mask.py +++ b/src/art/megatron/context_parallel/block_mask.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import cast + import numpy as np import torch from torch.nn.attention.flex_attention import BlockMask @@ -11,6 +13,7 @@ _INVALID_Q_GROUP = -(1 << 63) _INVALID_Q_PARENT = _INVALID_Q_GROUP + 1 _INVALID_K_GROUP = _INVALID_Q_GROUP + 2 +_INVALID_POSITION = _INVALID_Q_GROUP + 3 def _build_exact_mask_mod( @@ -20,14 +23,40 @@ def _build_exact_mask_mod( q_group: np.ndarray, q_parent: np.ndarray, k_group: np.ndarray, + q_pos: np.ndarray | None, + k_pos: np.ndarray | None, + sliding_window: int | None, device: torch.device, ): - q_abs_tensor = torch.as_tensor(q_abs, device=device, dtype=torch.int64) - k_abs_tensor = torch.as_tensor(k_abs, device=device, dtype=torch.int64) q_group_tensor = torch.as_tensor(q_group, device=device, dtype=torch.int64) q_parent_tensor = torch.as_tensor(q_parent, device=device, dtype=torch.int64) k_group_tensor = torch.as_tensor(k_group, device=device, dtype=torch.int64) + if sliding_window is not None: + q_pos_tensor = torch.as_tensor(q_pos, device=device, dtype=torch.int64) + k_pos_tensor = torch.as_tensor(k_pos, device=device, dtype=torch.int64) + + def sliding_mask_mod( + batch_idx: torch.Tensor, + head_idx: torch.Tensor, + 
query_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + del batch_idx, head_idx + same_group = q_group_tensor[query_idx] == k_group_tensor[kv_idx] + parent_prefix = q_parent_tensor[query_idx] == k_group_tensor[kv_idx] + delta = q_pos_tensor[query_idx] - k_pos_tensor[kv_idx] + return ( + (same_group | parent_prefix) + & (delta >= 0) + & (delta < int(sliding_window)) + ) + + return sliding_mask_mod + + q_abs_tensor = torch.as_tensor(q_abs, device=device, dtype=torch.int64) + k_abs_tensor = torch.as_tensor(k_abs, device=device, dtype=torch.int64) + def mask_mod( batch_idx: torch.Tensor, head_idx: torch.Tensor, @@ -78,23 +107,29 @@ def _build_q_block_group_state( q_group: np.ndarray, q_parent: np.ndarray, q_block: int, - block_idx: int, -) -> tuple[int, dict[int, int], frozenset[int]]: - start = int(block_idx) * q_block - end = min((int(block_idx) + 1) * q_block, int(q_abs.size)) - q = q_abs[start:end] - q_group_block = q_group[start:end] - q_parent_block = q_parent[start:end] - q_min = int(q.min()) if int(q.size) else 0 - max_by_group: dict[int, int] = {} - all_groups: list[int] = [] - for group_value in np.unique(np.concatenate((q_group_block, q_parent_block))): - allowed = (q_group_block == group_value) | (q_parent_block == group_value) - if bool(allowed.any()): - max_by_group[int(group_value)] = int(q[allowed].max()) - if bool(allowed.all()): - all_groups.append(int(group_value)) - return q_min, max_by_group, frozenset(all_groups) + q_blocks: int, +) -> tuple[np.ndarray, list[dict[int, int]], list[frozenset[int]]]: + q_min_by_block = np.empty((q_blocks,), dtype=np.int64) + q_allowed_max_by_group: list[dict[int, int]] = [] + q_all_allowed_groups: list[frozenset[int]] = [] + for block_idx in range(q_blocks): + start = block_idx * q_block + end = min((block_idx + 1) * q_block, int(q_abs.size)) + q = q_abs[start:end] + q_group_block = q_group[start:end] + q_parent_block = q_parent[start:end] + q_min_by_block[block_idx] = int(q.min()) if int(q.size) else 
0 + max_by_group: dict[int, int] = {} + all_groups: list[int] = [] + for group_value in np.unique(np.concatenate((q_group_block, q_parent_block))): + allowed = (q_group_block == group_value) | (q_parent_block == group_value) + if bool(allowed.any()): + max_by_group[int(group_value)] = int(q[allowed].max()) + if bool(allowed.all()): + all_groups.append(int(group_value)) + q_allowed_max_by_group.append(max_by_group) + q_all_allowed_groups.append(frozenset(all_groups)) + return q_min_by_block, q_allowed_max_by_group, q_all_allowed_groups def _build_k_block_group_state( @@ -102,34 +137,97 @@ def _build_k_block_group_state( k_abs: np.ndarray, k_group: np.ndarray, k_block: int, - block_idx: int, -) -> tuple[int, dict[int, int], tuple[int, ...]]: - start = int(block_idx) * k_block - end = min((int(block_idx) + 1) * k_block, int(k_abs.size)) - k = k_abs[start:end] - k_group_block = k_group[start:end] - k_max = int(k.max()) if int(k.size) else 0 - min_by_group: dict[int, int] = {} - for group_value in np.unique(k_group_block): - min_by_group[int(group_value)] = int(k[k_group_block == group_value].min()) - return k_max, min_by_group, tuple(min_by_group) + k_blocks: int, +) -> tuple[np.ndarray, list[dict[int, int]], list[tuple[int, ...]]]: + k_max_by_block = np.empty((k_blocks,), dtype=np.int64) + k_min_by_group: list[dict[int, int]] = [] + k_groups_by_block: list[tuple[int, ...]] = [] + for block_idx in range(k_blocks): + start = block_idx * k_block + end = min((block_idx + 1) * k_block, int(k_abs.size)) + k = k_abs[start:end] + k_group_block = k_group[start:end] + k_max_by_block[block_idx] = int(k.max()) if int(k.size) else 0 + min_by_group: dict[int, int] = {} + for group_value in np.unique(k_group_block): + min_by_group[int(group_value)] = int(k[k_group_block == group_value].min()) + k_min_by_group.append(min_by_group) + k_groups_by_block.append(tuple(min_by_group)) + return k_max_by_block, k_min_by_group, k_groups_by_block def _exact_block_state( *, - q_state: tuple[int, 
dict[int, int], frozenset[int]], - k_state: tuple[int, dict[int, int], tuple[int, ...]], + q_idx: int, + k_idx: int, + q_min_by_block: np.ndarray, + q_allowed_max_by_group: list[dict[int, int]], + q_all_allowed_groups: list[frozenset[int]], + k_max_by_block: np.ndarray, + k_min_by_group: list[dict[int, int]], + k_groups_by_block: list[tuple[int, ...]], ) -> tuple[bool, bool]: - q_min, q_allowed_max, q_all_allowed = q_state - k_max, k_min, k_groups = k_state + q_allowed_max = q_allowed_max_by_group[q_idx] + k_min = k_min_by_group[k_idx] if not any( q_allowed_max.get(k_group_value, _INVALID_Q_GROUP) >= min_k for k_group_value, min_k in k_min.items() ): return False, False - if int(q_min) < int(k_max): + if int(q_min_by_block[q_idx]) < int(k_max_by_block[k_idx]): return True, False - return True, all(k_group_value in q_all_allowed for k_group_value in k_groups) + q_all_allowed = q_all_allowed_groups[q_idx] + return True, all( + k_group_value in q_all_allowed for k_group_value in k_groups_by_block[k_idx] + ) + + +def _range_min_max( + values: np.ndarray, + starts: np.ndarray, + ends: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + mins = np.empty((int(starts.size),), dtype=np.int64) + maxes = np.empty((int(starts.size),), dtype=np.int64) + for index, (start, end) in enumerate(zip(starts, ends, strict=True)): + block = values[int(start) : int(end)] + mins[index] = int(block.min()) if int(block.size) else 0 + maxes[index] = int(block.max()) if int(block.size) else 0 + return mins, maxes + + +def _exact_sliding_block_state( + *, + q_idx: int, + k_idx: int, + q_block: int, + k_block: int, + q_len: int, + k_len: int, + q_group: np.ndarray, + q_parent: np.ndarray, + k_group: np.ndarray, + q_pos: np.ndarray, + k_pos: np.ndarray, + sliding_window: int, +) -> tuple[bool, bool]: + q_start = int(q_idx) * int(q_block) + q_end = min(q_start + int(q_block), int(q_len)) + k_start = int(k_idx) * int(k_block) + k_end = min(k_start + int(k_block), int(k_len)) + q_group_block = 
q_group[q_start:q_end] + q_parent_block = q_parent[q_start:q_end] + k_group_block = k_group[k_start:k_end] + delta = q_pos[q_start:q_end, None] - k_pos[None, k_start:k_end] + allowed = ( + ( + (q_group_block[:, None] == k_group_block[None, :]) + | (q_parent_block[:, None] == k_group_block[None, :]) + ) + & (delta >= 0) + & (delta < int(sliding_window)) + ) + return bool(allowed.any()), bool(allowed.all()) def _build_sparse_block_mask( @@ -138,6 +236,8 @@ def _build_sparse_block_mask( device: torch.device, group_ids: torch.Tensor, parent_ids: torch.Tensor, + input_pos: torch.Tensor | None, + sliding_window: int | None, block_size: tuple[int, int], ) -> BlockMask: q_block, k_block = block_size @@ -162,6 +262,15 @@ def _build_sparse_block_mask( ) flat_group_ids_np = flat_group_ids.numpy() flat_parent_ids_np = flat_parent_ids.numpy() + input_pos_for_window = cast(torch.Tensor, input_pos) + flat_input_pos_np = ( + input_pos_for_window.detach() + .to(device="cpu", dtype=torch.int64) + .reshape(-1) + .numpy() + if sliding_window is not None + else None + ) q_group = _select_with_invalid_np( flat_group_ids_np, q_abs, @@ -177,14 +286,50 @@ def _build_sparse_block_mask( k_abs, invalid_value=_INVALID_K_GROUP, ) + q_pos = ( + _select_with_invalid_np( + cast(np.ndarray, flat_input_pos_np), + q_abs, + invalid_value=_INVALID_POSITION, + ) + if sliding_window is not None + else None + ) + k_pos = ( + _select_with_invalid_np( + cast(np.ndarray, flat_input_pos_np), + k_abs, + invalid_value=_INVALID_POSITION, + ) + if sliding_window is not None + else None + ) mask_mod = _build_exact_mask_mod( q_abs=q_abs, k_abs=k_abs, q_group=q_group, q_parent=q_parent, k_group=k_group, + q_pos=q_pos, + k_pos=k_pos, + sliding_window=sliding_window, device=device, ) + q_min_by_block, q_allowed_max_by_group, q_all_allowed_groups = ( + _build_q_block_group_state( + q_abs=q_abs, + q_group=q_group, + q_parent=q_parent, + q_block=q_block, + q_blocks=q_blocks, + ) + ) + k_max_by_block, k_min_by_group, 
k_groups_by_block = _build_k_block_group_state( + k_abs=k_abs, + k_group=k_group, + k_block=k_block, + k_blocks=k_blocks, + ) if not spec.slices: raise RuntimeError( "Cannot build a CP attention block mask without stage slices" @@ -237,10 +382,32 @@ def _build_sparse_block_mask( q_max = q_abs[q_overlap_end - 1] k_min = k_abs[k_overlap_start] k_max = k_abs[k_overlap_end - 1] + if sliding_window is not None: + q_pos_for_window = cast(np.ndarray, q_pos) + k_pos_for_window = cast(np.ndarray, k_pos) + q_pos_min, q_pos_max = _range_min_max( + q_pos_for_window, + q_overlap_start, + q_overlap_end, + ) + k_pos_min, k_pos_max = _range_min_max( + k_pos_for_window, + k_overlap_start, + k_overlap_end, + ) + window_has_any = (q_pos_max[:, None] >= k_pos_min[None, :]) & ( + q_pos_min[:, None] - k_pos_max[None, :] < int(sliding_window) + ) + window_is_full = (q_pos_min[:, None] >= k_pos_max[None, :]) & ( + q_pos_max[:, None] - k_pos_min[None, :] < int(sliding_window) + ) q_is_full = (q_overlap_start == q_block_start) & (q_overlap_end == q_block_end) k_is_full = (k_overlap_start == k_block_start) & (k_overlap_end == k_block_end) covers_block = q_is_full[:, None] & k_is_full[None, :] - if slice_.mask_kind == AttnMaskKind.FULL: + if sliding_window is not None: + has_any = window_has_any + is_full = covers_block & window_is_full + elif slice_.mask_kind == AttnMaskKind.FULL: has_any = np.ones( (int(q_block_indices.size), int(k_block_indices.size)), dtype=bool ) @@ -256,32 +423,33 @@ def _build_sparse_block_mask( full_blocks[q_slice, k_slice] |= is_full ambiguous = (touch_counts > 1) & partial_blocks & ~full_blocks - q_state_cache: dict[int, tuple[int, dict[int, int], frozenset[int]]] = {} - k_state_cache: dict[int, tuple[int, dict[int, int], tuple[int, ...]]] = {} for q_idx, k_idx in np.argwhere(ambiguous): - q_state = q_state_cache.get(int(q_idx)) - if q_state is None: - q_state = _build_q_block_group_state( - q_abs=q_abs, + if sliding_window is None: + has_any, is_full = 
_exact_block_state( + q_idx=int(q_idx), + k_idx=int(k_idx), + q_min_by_block=q_min_by_block, + q_allowed_max_by_group=q_allowed_max_by_group, + q_all_allowed_groups=q_all_allowed_groups, + k_max_by_block=k_max_by_block, + k_min_by_group=k_min_by_group, + k_groups_by_block=k_groups_by_block, + ) + else: + has_any, is_full = _exact_sliding_block_state( + q_idx=int(q_idx), + k_idx=int(k_idx), + q_block=q_block, + k_block=k_block, + q_len=int(spec.q_len), + k_len=int(spec.k_len), q_group=q_group, q_parent=q_parent, - q_block=q_block, - block_idx=int(q_idx), - ) - q_state_cache[int(q_idx)] = q_state - k_state = k_state_cache.get(int(k_idx)) - if k_state is None: - k_state = _build_k_block_group_state( - k_abs=k_abs, k_group=k_group, - k_block=k_block, - block_idx=int(k_idx), + q_pos=cast(np.ndarray, q_pos), + k_pos=cast(np.ndarray, k_pos), + sliding_window=int(sliding_window), ) - k_state_cache[int(k_idx)] = k_state - has_any, is_full = _exact_block_state( - q_state=q_state, - k_state=k_state, - ) partial_blocks[q_idx, k_idx] = False full_blocks[q_idx, k_idx] = False if is_full: @@ -298,117 +466,24 @@ def _build_sparse_block_mask( full_blocks, device=device, ) - q_num_blocks, q_indices = _dense_blocks_to_ordered( - partial_blocks.T, - device=device, - ) - full_q_num_blocks, full_q_indices = _dense_blocks_to_ordered( - full_blocks.T, - device=device, - ) - return BlockMask( - seq_lengths=(int(spec.q_len), int(spec.k_len)), - kv_num_blocks=kv_num_blocks, - kv_indices=kv_indices, - full_kv_num_blocks=full_kv_num_blocks, - full_kv_indices=full_kv_indices, - q_num_blocks=q_num_blocks, - q_indices=q_indices, - full_q_num_blocks=full_q_num_blocks, - full_q_indices=full_q_indices, + return BlockMask.from_kv_blocks( + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, BLOCK_SIZE=block_size, mask_mod=mask_mod, + seq_lengths=(int(spec.q_len), int(spec.k_len)), ) -def _valid_prefix(indices: torch.Tensor, *, name: str) -> torch.Tensor: - if indices.ndim != 1: - 
raise RuntimeError(f"{name} exact token indices must be rank 1.") - if indices.dtype != torch.int64: - raise RuntimeError(f"{name} exact token indices must be int64.") - indices_cpu = indices.detach().to(device="cpu", dtype=torch.int64).contiguous() - invalid = indices_cpu < 0 - if bool(invalid.any().item()): - first_invalid = int(torch.nonzero(invalid, as_tuple=False)[0].item()) - if bool((indices_cpu[first_invalid:] >= 0).any().item()): - raise RuntimeError( - f"{name} exact token indices must use only contiguous tail padding." - ) - return indices_cpu[:first_invalid] - return indices_cpu - - -def _validate_exact_indices( - indices: torch.Tensor, - *, - name: str, - source_len: int, -) -> int: - valid = _valid_prefix(indices, name=name) - if int(valid.numel()) == 0: - return 0 - if bool((valid[1:] <= valid[:-1]).any().item()): - raise RuntimeError(f"{name} exact token indices must be strictly increasing.") - max_index = int(valid[-1].item()) - if max_index >= int(source_len): - raise RuntimeError( - f"{name} exact token index {max_index} exceeds source metadata length {int(source_len)}." - ) - return int(valid.numel()) - - -def _validate_supported_mask_spec( - spec: FlexMaskSpec, - *, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, -) -> None: - if group_ids.ndim != 1 or parent_ids.ndim != 1: - raise RuntimeError( - "Shared-prefix sparse block masks require rank-1 group_ids and parent_ids." - ) - if int(group_ids.numel()) != int(parent_ids.numel()): - raise RuntimeError( - "Shared-prefix sparse block masks require equal group_ids and parent_ids lengths." 
- ) - q_valid_len = _validate_exact_indices( - spec.exact_mask.q_token_indices, - name="q", - source_len=int(group_ids.numel()), - ) - k_valid_len = _validate_exact_indices( - spec.exact_mask.k_token_indices, - name="k", - source_len=int(group_ids.numel()), - ) - for slice_ in spec.slices: - if int(slice_.row_index) != 0: - raise RuntimeError( - "Shared-prefix sparse block masks support exactly one packed row." - ) - if slice_.mask_kind not in {AttnMaskKind.FULL, AttnMaskKind.CAUSAL}: - raise RuntimeError(f"Unsupported attention mask kind: {slice_.mask_kind}") - if ( - slice_.q_range.start < 0 - or slice_.q_range.end > int(spec.q_len) - or slice_.k_range.start < 0 - or slice_.k_range.end > int(spec.k_len) - or slice_.q_range.end < slice_.q_range.start - or slice_.k_range.end < slice_.k_range.start - ): - raise RuntimeError(f"Attention slice is outside mask bounds: {slice_}") - if slice_.q_range.end > q_valid_len or slice_.k_range.end > k_valid_len: - raise RuntimeError( - "Attention slices may not cover exact-index tail padding: " - f"slice={slice_}, q_valid_len={q_valid_len}, k_valid_len={k_valid_len}" - ) - - def build_block_mask( spec: FlexMaskSpec, *, group_ids: torch.Tensor, parent_ids: torch.Tensor, + input_pos: torch.Tensor | None = None, + sliding_window: int | None = None, device: torch.device, ) -> BlockMask | None: if spec.q_len <= 0 or spec.k_len <= 0: @@ -423,12 +498,13 @@ def build_block_mask( "Exact stage k-token metadata length mismatch: " f"{int(spec.exact_mask.k_token_indices.numel())} != {int(spec.k_len)}" ) - _validate_supported_mask_spec(spec, group_ids=group_ids, parent_ids=parent_ids) block_size = normalize_sparse_block_size(spec.block_size) return _build_sparse_block_mask( spec, device=device, group_ids=group_ids, parent_ids=parent_ids, + input_pos=input_pos, + sliding_window=sliding_window, block_size=block_size, ) diff --git a/src/art/megatron/context_parallel/core_attention.py b/src/art/megatron/context_parallel/core_attention.py index 
8944878b7..6c986d72a 100644 --- a/src/art/megatron/context_parallel/core_attention.py +++ b/src/art/megatron/context_parallel/core_attention.py @@ -34,13 +34,8 @@ def __init__( pg_collection: ProcessGroupCollection | None = None, ): super().__init__() - del ( - layer_number, - attn_mask_type, - attention_type, - attention_dropout, - cp_comm_type, - ) + del attn_mask_type, attention_type, attention_dropout, cp_comm_type + self.layer_number = int(layer_number) self.config = config self.dense_kernel = FlexAttentionWrapper() @@ -99,10 +94,13 @@ def forward( enable_gqa=self.num_attention_heads_per_partition != self.num_query_groups_per_partition, compile_enabled=True, + sliding_window=getattr(self, "art_sliding_window", None), ) else: if isinstance(attention_bias, SharedPrefixAttentionState): - block_mask = attention_bias.block_mask + block_mask = attention_bias.block_mask_for_window( + getattr(self, "art_sliding_window", None) + ) else: assert isinstance(attention_bias, BlockMask), ( "Expected ArtContextParallelState, SharedPrefixAttentionState, or BlockMask in attention_bias." 
diff --git a/src/art/megatron/context_parallel/executor.py b/src/art/megatron/context_parallel/executor.py index e5e219e72..7994b0e8a 100644 --- a/src/art/megatron/context_parallel/executor.py +++ b/src/art/megatron/context_parallel/executor.py @@ -12,6 +12,7 @@ from art.megatron.flex_attn.compiled import ( SparseBlockSize, flash_sparse_block_size_for_head_dim, + flex_backend_for_head_dims, get_sparse_compiled_flex_attention, normalize_flex_lse, normalize_sparse_block_size, @@ -29,6 +30,7 @@ from .types import ( ArtContextParallelState, AttnSlice, + CpBlockMaskVariant, DkvReducePlan, ExactMaskMetadata, FlexMaskSpec, @@ -57,10 +59,6 @@ def _stage_sparse_block_size( ) -def _execution_sparse_block_size(state: ArtContextParallelState) -> SparseBlockSize: - return state.config.attention_sparse_block_size or state.config.block_size - - def _pad_exact_indices(indices: torch.Tensor, target_len: int) -> torch.Tensor: current_len = int(indices.numel()) target_len = int(target_len) @@ -612,6 +610,10 @@ def run( raise RuntimeError( "ART context parallel attention requires a concrete block mask for compiled flex attention." 
) + backend = flex_backend_for_head_dims( + head_dim=int(q.shape[-1]), + head_dim_v=int(v.shape[-1]), + ) if compile_key is None: _q_len, _k_len, compile_key = select_sparse_execution_family( is_local_stage=bool(is_local_stage), @@ -621,9 +623,10 @@ def run( ) compiled_flex_attention = ( sparse_compiled_flex_attention - if str(compile_key) == "sparse" + if str(compile_key) == "sparse" and backend == "FLASH" else get_sparse_compiled_flex_attention( family_key=str(compile_key), + backend=backend, ) ) out, aux = cast( @@ -641,7 +644,7 @@ def run( lse = aux.lse if lse is None: raise RuntimeError("Compiled flex attention did not return lse.") - lse = normalize_flex_lse(lse) + lse = normalize_flex_lse(lse, backend=backend) return out, lse @@ -652,10 +655,11 @@ def _build_stage_block_mask( device: torch.device, execution_spec: StageExecutionSpec | None = None, block_size: SparseBlockSize | None = None, + sliding_window: int | None = None, ) -> BlockMask | None: - resolved_block_size = normalize_sparse_block_size( - _execution_sparse_block_size(state) if block_size is None else block_size - ) + if block_size is None: + block_size = state.config.attention_sparse_block_size or state.config.block_size + resolved_block_size = normalize_sparse_block_size(block_size) execution_spec = ( _resolve_stage_execution_spec( stage_plan=stage_plan, @@ -669,6 +673,7 @@ def _build_stage_block_mask( stage_plan=stage_plan, execution_spec=execution_spec, block_size=resolved_block_size, + sliding_window=sliding_window, device=device, ) cache = state.execution_cache.block_masks @@ -694,6 +699,8 @@ def _build_stage_block_mask( ), group_ids=state.group_ids, parent_ids=state.parent_ids, + input_pos=state.input_pos, + sliding_window=sliding_window, device=device, ) cache[cache_key] = mask @@ -707,12 +714,13 @@ def _get_prepared_stage_block_mask( device: torch.device, execution_spec: StageExecutionSpec, block_size: SparseBlockSize, + sliding_window: int | None, ) -> BlockMask: - resolved_block_size = 
normalize_sparse_block_size(block_size) cache_key = _stage_block_mask_cache_key( stage_plan=stage_plan, execution_spec=execution_spec, - block_size=resolved_block_size, + block_size=normalize_sparse_block_size(block_size), + sliding_window=sliding_window, device=device, ) cache = state.execution_cache.block_masks @@ -721,15 +729,16 @@ def _get_prepared_stage_block_mask( "ART context parallel forward hit an unprepared stage block-mask cache key. " "Mask construction is CPU planning work and must finish before model forward. " f"stage={int(stage_plan.stage_index)} q_len={int(execution_spec.q_len)} " - f"k_len={int(execution_spec.k_len)} block_size={resolved_block_size} " - f"device={device}" + f"k_len={int(execution_spec.k_len)} " + f"block_size={normalize_sparse_block_size(block_size)} " + f"sliding_window={sliding_window} device={device}" ) block_mask = cache[cache_key] if block_mask is None: raise RuntimeError( "ART context parallel forward found an empty prepared block mask for a non-empty stage. 
" f"stage={int(stage_plan.stage_index)} q_len={int(execution_spec.q_len)} " - f"k_len={int(execution_spec.k_len)}" + f"k_len={int(execution_spec.k_len)} sliding_window={sliding_window}" ) return cast(BlockMask, block_mask) @@ -739,13 +748,15 @@ def _stage_block_mask_cache_key( stage_plan: StagePlan, execution_spec: StageExecutionSpec, block_size: tuple[int, int], + sliding_window: int | None, device: torch.device, -) -> tuple[int, int, int, tuple[int, int], str, int | None]: +) -> tuple[int, int, int, tuple[int, int], int | None, str, int | None]: return ( int(stage_plan.stage_index), int(execution_spec.q_len), int(execution_spec.k_len), block_size, + None if sliding_window is None else int(sliding_window), device.type, device.index, ) @@ -756,22 +767,29 @@ def prepare_context_parallel_execution_state( state: ArtContextParallelState, device: torch.device, ) -> None: - block_size = _execution_sparse_block_size(state) + variants = state.block_mask_variants or ( + CpBlockMaskVariant( + sliding_window=None, + block_size=normalize_sparse_block_size(state.config.block_size), + ), + ) for stage_plan in state.rank_plan.stage_plans: if stage_plan.q_len <= 0 or stage_plan.k_len <= 0 or not stage_plan.slices: continue - execution_spec = _resolve_stage_execution_spec( - stage_plan=stage_plan, - state=state, - block_size=block_size, - ) - _build_stage_block_mask( - stage_plan=stage_plan, - state=state, - device=device, - execution_spec=execution_spec, - block_size=block_size, - ) + for variant in variants: + execution_spec = _resolve_stage_execution_spec( + stage_plan=stage_plan, + state=state, + block_size=variant.block_size, + ) + _build_stage_block_mask( + stage_plan=stage_plan, + state=state, + device=device, + execution_spec=execution_spec, + block_size=variant.block_size, + sliding_window=variant.sliding_window, + ) def _causal_slice_pair_count(slice_: AttnSlice) -> int: @@ -852,63 +870,53 @@ def _resolve_stage_execution_spec( resolved_block_size = 
normalize_sparse_block_size( state.config.block_size if block_size is None else block_size ) - cache_key = _stage_execution_spec_cache_key( - stage_plan=stage_plan, - block_size=resolved_block_size, - ) - cache = state.execution_cache.stage_execution_specs + cache_key = (int(stage_plan.stage_index), resolved_block_size) + execution_cache = getattr(state, "execution_cache", None) + if execution_cache is None: + target_q_len, target_k_len, compile_key = select_sparse_execution_family( + is_local_stage=bool(stage_plan.is_local_stage), + q_len=int(stage_plan.q_len), + k_len=int(stage_plan.k_len), + block_size=resolved_block_size, + ) + return StageExecutionSpec( + q_len=int(target_q_len), + k_len=int(target_k_len), + compile_key=str(compile_key), + mask_metadata=_resize_exact_mask_metadata( + stage_plan.mask_metadata, + q_len=int(target_q_len), + k_len=int(target_k_len), + ), + ) + cache = getattr(execution_cache, "stage_execution_specs", None) + if cache is None: + target_q_len, target_k_len, compile_key = select_sparse_execution_family( + is_local_stage=bool(stage_plan.is_local_stage), + q_len=int(stage_plan.q_len), + k_len=int(stage_plan.k_len), + block_size=resolved_block_size, + ) + return StageExecutionSpec( + q_len=int(target_q_len), + k_len=int(target_k_len), + compile_key=str(compile_key), + mask_metadata=_resize_exact_mask_metadata( + stage_plan.mask_metadata, + q_len=int(target_q_len), + k_len=int(target_k_len), + ), + ) cached = cache.get(cache_key) if cached is not None: return cached - resolved = _build_stage_execution_spec( - stage_plan=stage_plan, - block_size=resolved_block_size, - ) - cache[cache_key] = resolved - return resolved - - -def _get_prepared_stage_execution_spec( - *, - stage_plan: StagePlan, - state: ArtContextParallelState, - block_size: SparseBlockSize, -) -> StageExecutionSpec: - resolved_block_size = normalize_sparse_block_size(block_size) - cache_key = _stage_execution_spec_cache_key( - stage_plan=stage_plan, - 
block_size=resolved_block_size, - ) - cached = state.execution_cache.stage_execution_specs.get(cache_key) - if cached is None: - raise RuntimeError( - "ART context parallel forward hit an unprepared stage execution-spec cache key. " - "Execution planning must finish before model forward. " - f"stage={int(stage_plan.stage_index)} block_size={resolved_block_size}" - ) - return cached - - -def _stage_execution_spec_cache_key( - *, - stage_plan: StagePlan, - block_size: tuple[int, int], -) -> tuple[int, tuple[int, int]]: - return (int(stage_plan.stage_index), block_size) - - -def _build_stage_execution_spec( - *, - stage_plan: StagePlan, - block_size: tuple[int, int], -) -> StageExecutionSpec: target_q_len, target_k_len, compile_key = select_sparse_execution_family( is_local_stage=bool(stage_plan.is_local_stage), q_len=int(stage_plan.q_len), k_len=int(stage_plan.k_len), - block_size=block_size, + block_size=resolved_block_size, ) - return StageExecutionSpec( + resolved = StageExecutionSpec( q_len=int(target_q_len), k_len=int(target_k_len), compile_key=str(compile_key), @@ -918,6 +926,8 @@ def _build_stage_execution_spec( k_len=int(target_k_len), ), ) + cache[cache_key] = resolved + return resolved def _run_stage_attention( @@ -930,9 +940,10 @@ def _run_stage_attention( kernel: FlexAttentionKernel, scale: float, enable_gqa: bool, + sliding_window: int | None, ) -> tuple[torch.Tensor, torch.Tensor]: sparse_block_size = _stage_sparse_block_size(q_stage, v_stage) - execution_spec = _get_prepared_stage_execution_spec( + execution_spec = _resolve_stage_execution_spec( stage_plan=stage_plan, state=state, block_size=sparse_block_size, @@ -943,6 +954,7 @@ def _run_stage_attention( device=q_stage.device, execution_spec=execution_spec, block_size=sparse_block_size, + sliding_window=sliding_window, ) _validate_stage_block_alignment( q_len=int(execution_spec.q_len), @@ -986,12 +998,16 @@ def _run_stage_attention( out_tokens = out.squeeze(0) lse_tokens = 
lse.squeeze(0).to(dtype=torch.float32) return ( - out_tokens[:, :logical_q_len] - if int(execution_spec.q_len) == logical_q_len - else out_tokens[:, :logical_q_len].contiguous(), - lse_tokens[:, :logical_q_len] - if int(execution_spec.q_len) == logical_q_len - else lse_tokens[:, :logical_q_len].contiguous(), + ( + out_tokens[:, :logical_q_len] + if int(execution_spec.q_len) == logical_q_len + else out_tokens[:, :logical_q_len].contiguous() + ), + ( + lse_tokens[:, :logical_q_len] + if int(execution_spec.q_len) == logical_q_len + else lse_tokens[:, :logical_q_len].contiguous() + ), ) out_tokens = out.squeeze(0).permute(1, 0, 2).contiguous() lse_tokens = lse.squeeze(0).permute(1, 0).contiguous().to(dtype=torch.float32) @@ -1491,6 +1507,7 @@ def _forward_stage_records( kernel: FlexAttentionKernel, scale: float, enable_gqa: bool, + sliding_window: int | None, record_for_backward: bool, ) -> tuple[torch.Tensor, list[dict[str, Any]]]: q_source = q_flat.detach() if record_for_backward else q_flat @@ -1585,6 +1602,7 @@ def _forward_stage_records( kernel=kernel, scale=scale, enable_gqa=enable_gqa, + sliding_window=sliding_window, ) replay_records.append( { @@ -1606,6 +1624,7 @@ def _forward_stage_records( kernel=kernel, scale=scale, enable_gqa=enable_gqa, + sliding_window=sliding_window, ) stage_out_value = stage_out.detach() if record_for_backward else stage_out stage_lse_value = stage_lse.detach() if record_for_backward else stage_lse @@ -1714,6 +1733,7 @@ def _forward_stage_records( kernel=kernel, scale=scale, enable_gqa=enable_gqa, + sliding_window=sliding_window, ) replay_records.append( { @@ -1735,6 +1755,7 @@ def _forward_stage_records( kernel=kernel, scale=scale, enable_gqa=enable_gqa, + sliding_window=sliding_window, ) stage_out_value = stage_out.detach() if record_for_backward else stage_out stage_lse_value = stage_lse.detach() if record_for_backward else stage_lse @@ -1799,9 +1820,10 @@ def _forward_stage_records( if not produced_output: if int(q_flat.shape[1]) 
== 0: - return q_flat.new_empty( - (q_flat.shape[0], 0, q_flat.shape[2]) - ), replay_records + return ( + q_flat.new_empty((q_flat.shape[0], 0, q_flat.shape[2])), + replay_records, + ) raise RuntimeError("Sparse attention produced no stage outputs") if accum_out is None: raise RuntimeError("Sparse attention produced no accumulated output") @@ -1831,6 +1853,7 @@ def _run_context_parallel_forward( scale: float, enable_gqa: bool, compile_enabled: bool, + sliding_window: int | None, ) -> torch.Tensor: kernel = FlexAttentionKernel(compile_enabled=compile_enabled) q_flat, k_flat, v_flat = _flatten_qkv( @@ -1847,6 +1870,7 @@ def _run_context_parallel_forward( kernel=kernel, scale=scale, enable_gqa=enable_gqa, + sliding_window=sliding_window, record_for_backward=False, ) return unflatten_valid_sequence_head_major( @@ -1865,6 +1889,7 @@ def _run_context_parallel_forward_recorded( scale: float, enable_gqa: bool, compile_enabled: bool, + sliding_window: int | None, ) -> tuple[torch.Tensor, torch.Tensor, list[dict[str, Any]]]: kernel = FlexAttentionKernel(compile_enabled=compile_enabled) q_flat, k_flat, v_flat = _flatten_qkv( @@ -1881,6 +1906,7 @@ def _run_context_parallel_forward_recorded( kernel=kernel, scale=scale, enable_gqa=enable_gqa, + sliding_window=sliding_window, record_for_backward=True, ) return ( @@ -1962,6 +1988,7 @@ def _run_context_parallel_backward( scale: float, enable_gqa: bool, compile_enabled: bool, + sliding_window: int | None, replay_records: list[dict[str, Any]] | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: kernel = FlexAttentionKernel(compile_enabled=compile_enabled) @@ -1985,6 +2012,7 @@ def _run_context_parallel_backward( kernel=kernel, scale=scale, enable_gqa=enable_gqa, + sliding_window=sliding_window, record_for_backward=True, ) stage_out_grads, stage_lse_grads = _merge_stage_output_grads_from_tape( @@ -2004,12 +2032,16 @@ def _run_context_parallel_backward( ): stage_plan = cast(StagePlan, record["stage_plan"]) 
grad_by_stage_index[int(stage_plan.stage_index)] = ( - _zero_stage_grads_like(record["stage_out"]) - if stage_out_grad is None - else stage_out_grad, - _zero_stage_grads_like(record["stage_lse"]) - if stage_lse_grad is None - else stage_lse_grad, + ( + _zero_stage_grads_like(record["stage_out"]) + if stage_out_grad is None + else stage_out_grad + ), + ( + _zero_stage_grads_like(record["stage_lse"]) + if stage_lse_grad is None + else stage_lse_grad + ), ) del stage_out_grads, stage_lse_grads @@ -2211,11 +2243,13 @@ def forward( scale: float, enable_gqa: bool, compile_enabled: bool, + sliding_window: int | None, ) -> torch.Tensor: ctx.state = state ctx.scale = float(scale) ctx.enable_gqa = bool(enable_gqa) ctx.compile_enabled = bool(compile_enabled) + ctx.sliding_window = sliding_window ctx.save_for_backward(query, key, value) with torch.enable_grad(): query_record = query.detach().requires_grad_(bool(query.requires_grad)) @@ -2230,6 +2264,7 @@ def forward( scale=float(scale), enable_gqa=bool(enable_gqa), compile_enabled=bool(compile_enabled), + sliding_window=sliding_window, ) ) ctx.replay_records = replay_records @@ -2249,11 +2284,12 @@ def backward(ctx, *grad_outputs: Any): scale=ctx.scale, enable_gqa=ctx.enable_gqa, compile_enabled=ctx.compile_enabled, + sliding_window=ctx.sliding_window, replay_records=cast(list[dict[str, Any]], ctx.replay_records), ) finally: ctx.replay_records = None - return dq, dk, dv, None, None, None, None + return dq, dk, dv, None, None, None, None, None def run_context_parallel( @@ -2265,6 +2301,7 @@ def run_context_parallel( scale: float, enable_gqa: bool, compile_enabled: bool, + sliding_window: int | None = None, ) -> torch.Tensor: if torch.is_grad_enabled() and ( query.requires_grad or key.requires_grad or value.requires_grad @@ -2277,6 +2314,7 @@ def run_context_parallel( float(scale), bool(enable_gqa), bool(compile_enabled), + None if sliding_window is None else int(sliding_window), ) return _run_context_parallel_forward( 
query=query, @@ -2286,4 +2324,5 @@ def run_context_parallel( scale=scale, enable_gqa=enable_gqa, compile_enabled=compile_enabled, + sliding_window=sliding_window, ) diff --git a/src/art/megatron/context_parallel/runtime.py b/src/art/megatron/context_parallel/runtime.py index c6eb9fddd..98ee47c1d 100644 --- a/src/art/megatron/context_parallel/runtime.py +++ b/src/art/megatron/context_parallel/runtime.py @@ -21,6 +21,7 @@ ContextParallelConfig, ContextParallelRuntimeKey, ContextParallelRuntimePlan, + CpBlockMaskVariant, DispatchedPackedTensors, DkvReducePlan, ExactMaskMetadata, @@ -1166,11 +1167,13 @@ def _stage_cost_ms( pair_ms = ( config.planner_local_backward_pair_ms if backward and local - else config.planner_remote_backward_pair_ms - if backward - else config.planner_local_pair_ms - if local - else config.planner_remote_pair_ms + else ( + config.planner_remote_backward_pair_ms + if backward + else ( + config.planner_local_pair_ms if local else config.planner_remote_pair_ms + ) + ) ) remote_underfill_ms = 0.0 if not local and (pair_count > 0 or q_tokens > 0 or k_tokens > 0): @@ -1308,21 +1311,27 @@ def _evaluate_plan( empty_pair_counts = pair_counts.new_zeros((0, chunk_count)) empty_pair_positive = pair_positive.new_zeros((0, chunk_count)) pair_counts_by_rank_rows = [ - empty_pair_counts - if int(owner_index.numel()) == 0 - else pair_counts.index_select(0, owner_index) + ( + empty_pair_counts + if int(owner_index.numel()) == 0 + else pair_counts.index_select(0, owner_index) + ) for owner_index in owner_indices ] pair_positive_by_rank_rows = [ - empty_pair_positive - if int(owner_index.numel()) == 0 - else pair_positive.index_select(0, owner_index) + ( + empty_pair_positive + if int(owner_index.numel()) == 0 + else pair_positive.index_select(0, owner_index) + ) for owner_index in owner_indices ] pair_positive_by_rank_cols = [ - torch.zeros(chunk_count, dtype=torch.bool) - if int(rank_rows.numel()) == 0 - else rank_rows.any(dim=0) + ( + torch.zeros(chunk_count, 
dtype=torch.bool) + if int(rank_rows.numel()) == 0 + else rank_rows.any(dim=0) + ) for rank_rows in pair_positive_by_rank_rows ] wave_masks = [wave_tensor == wave_index for wave_index in range(wave_count)] @@ -1979,27 +1988,37 @@ def _build_rank_runtime_plan( for wave_index in range(wave_count): request_ranges_by_source = tuple( - _merge_ranges(recv_request_ranges[wave_index][source_rank]) - if source_rank != target_rank - else tuple() + ( + _merge_ranges(recv_request_ranges[wave_index][source_rank]) + if source_rank != target_rank + else tuple() + ) for source_rank in range(cp_size) ) send_global_ranges_by_peer = tuple( - _merge_ranges(send_request_ranges[wave_index][peer_rank]) - if peer_rank != target_rank - else tuple() + ( + _merge_ranges(send_request_ranges[wave_index][peer_rank]) + if peer_rank != target_rank + else tuple() + ) for peer_rank in range(cp_size) ) send_ranges_by_peer = tuple( - tuple(_remap_subrange(range_, host_local_ranges) for range_ in peer_ranges) - if peer_rank != target_rank - else tuple() + ( + tuple( + _remap_subrange(range_, host_local_ranges) for range_ in peer_ranges + ) + if peer_rank != target_rank + else tuple() + ) for peer_rank, peer_ranges in enumerate(send_global_ranges_by_peer) ) recv_splits = tuple( - _ranges_size(request_ranges_by_source[source_rank]) - if source_rank != target_rank - else 0 + ( + _ranges_size(request_ranges_by_source[source_rank]) + if source_rank != target_rank + else 0 + ) for source_rank in range(cp_size) ) send_splits = tuple( @@ -2141,6 +2160,7 @@ def prepare_cp_micro( build_gdn_execution_spec: bool = False, trace_token_uids: bool = False, prepare_execution_state: bool = True, + block_mask_variants: tuple[CpBlockMaskVariant, ...] 
= (), target_device: torch.device | None = None, ref_logprobs: torch.Tensor | None = None, ) -> PreparedMegatronBatch: @@ -2159,6 +2179,7 @@ def prepare_cp_micro( cp_group=cp_group, cp_rank=cp_rank, build_gdn_execution_spec=build_gdn_execution_spec, + block_mask_variants=block_mask_variants, target_device=target_device, ) tensors = dispatch_megatron_context_parallel_training_tensors( @@ -2197,6 +2218,7 @@ def prepare_megatron_context_parallel_state( cp_group: Any, cp_rank: int, build_gdn_execution_spec: bool = False, + block_mask_variants: tuple[CpBlockMaskVariant, ...] = (), target_device: torch.device | None = None, ) -> tuple[ArtContextParallelState, RankRuntimePlan, PackedBatchAttentionSpec, int]: """Build CP runtime state from CPU metadata. @@ -2222,6 +2244,7 @@ def prepare_megatron_context_parallel_state( ) group_ids_cpu = _planning_metadata_cpu(micro["group_ids"]) parent_ids_cpu = _planning_metadata_cpu(micro["parent_ids"]) + input_pos_cpu = _planning_metadata_cpu(micro["input_pos"]) runtime_config = _config_for_runtime_cp(topology=topology, config=config) planning_key = _planning_bundle_cache_key( group_ids=group_ids_cpu, @@ -2303,6 +2326,8 @@ def prepare_megatron_context_parallel_state( config=runtime_config, group_ids=group_ids_cpu[0].contiguous(), parent_ids=parent_ids_cpu[0].contiguous(), + input_pos=input_pos_cpu[0].contiguous(), + block_mask_variants=block_mask_variants, gdn_execution_spec=bundle.gdn_execution_spec, gdn_execution_plan=gdn_execution_plan, planner_provenance=planner_provenance, @@ -2452,12 +2477,16 @@ def dispatch_megatron_context_parallel_training_tensors( advantages=_to_target_device(local_advantages, target_device), weights=_to_target_device(local_weights, target_device), valid_lengths=rank_plan.local_valid_lengths, - original_logprobs=None - if local_original_logprobs is None - else _to_target_device(local_original_logprobs, target_device), - ref_logprobs=None - if local_ref_logprobs is None - else 
_to_target_device(local_ref_logprobs, target_device), + original_logprobs=( + None + if local_original_logprobs is None + else _to_target_device(local_original_logprobs, target_device) + ), + ref_logprobs=( + None + if local_ref_logprobs is None + else _to_target_device(local_ref_logprobs, target_device) + ), loss_all_reduce_group=cp_group, token_uids=None if local_token_uids is None else local_token_uids.contiguous(), ) @@ -2892,11 +2921,13 @@ def _dispatch_tensor( rank_plan: RankRuntimePlan, pad_value: int | float | bool, pad_multiple: int = 1, - dispatch_meta_cache: dict[ - tuple[tuple[tuple[int, int], ...], int, str, int | None], - tuple[torch.Tensor, torch.Tensor], - ] - | None = None, + dispatch_meta_cache: ( + dict[ + tuple[tuple[tuple[int, int], ...], int, str, int | None], + tuple[torch.Tensor, torch.Tensor], + ] + | None + ) = None, ) -> torch.Tensor: """Gather local rows without branching on CUDA tensor values. @@ -2940,11 +2971,13 @@ def _dispatch_meta( rank_plan: RankRuntimePlan, max_local_len: int, device: torch.device, - dispatch_meta_cache: dict[ - tuple[tuple[tuple[int, int], ...], int, str, int | None], - tuple[torch.Tensor, torch.Tensor], - ] - | None = None, + dispatch_meta_cache: ( + dict[ + tuple[tuple[tuple[int, int], ...], int, str, int | None], + tuple[torch.Tensor, torch.Tensor], + ] + | None + ) = None, ) -> tuple[torch.Tensor, torch.Tensor]: owner_ranges = tuple( range_ diff --git a/src/art/megatron/context_parallel/types.py b/src/art/megatron/context_parallel/types.py index 5cc874d09..e94cf5584 100644 --- a/src/art/megatron/context_parallel/types.py +++ b/src/art/megatron/context_parallel/types.py @@ -231,6 +231,13 @@ class ContextParallelExecutionCache(BaseModel): stage_execution_specs: dict[Any, "StageExecutionSpec"] = Field(default_factory=dict) +class CpBlockMaskVariant(BaseModel): + model_config = ConfigDict(frozen=True) + + sliding_window: int | None = None + block_size: tuple[int, int] + + class StageExecutionSpec(BaseModel): 
model_config = ConfigDict(frozen=True) @@ -266,6 +273,8 @@ class ArtContextParallelState(BaseModel): config: ContextParallelConfig group_ids: torch.Tensor parent_ids: torch.Tensor + input_pos: torch.Tensor + block_mask_variants: tuple[CpBlockMaskVariant, ...] = () gdn_execution_spec: Any | None = None gdn_execution_plan: Any | None = None gdn_hidden_layout: str = "attention" diff --git a/src/art/megatron/flex_attn/attention.py b/src/art/megatron/flex_attn/attention.py index 483742cab..6b3506699 100644 --- a/src/art/megatron/flex_attn/attention.py +++ b/src/art/megatron/flex_attn/attention.py @@ -8,12 +8,15 @@ from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import divide -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field import torch from torch import Tensor from torch.nn.attention.flex_attention import BlockMask -from art.megatron.flex_attn.compiled import dense_compiled_flex_attention +from art.megatron.flex_attn.compiled import ( + flex_backend_for_head_dims, + get_dense_compiled_flex_attention, +) class SharedPrefixAttentionState(BaseModel): @@ -21,6 +24,12 @@ class SharedPrefixAttentionState(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) block_mask: BlockMask + sliding_block_masks: dict[int, BlockMask] = Field(default_factory=dict) + + def block_mask_for_window(self, window: int | None) -> BlockMask: + if window is None: + return self.block_mask + return self.sliding_block_masks[int(window)] class FlexAttentionWrapper(torch.nn.Module): @@ -37,9 +46,13 @@ def forward( enable_gqa: bool, ) -> Tensor: # q, k, v are [B, H, S, D] tensors expected by torch.flex_attention. 
+ backend = flex_backend_for_head_dims( + head_dim=int(q.shape[-1]), + head_dim_v=int(v.shape[-1]), + ) return cast( Tensor, - dense_compiled_flex_attention( + get_dense_compiled_flex_attention(backend=backend)( q, k, v, @@ -53,6 +66,9 @@ def forward( def create_shared_prefix_attention_state( group_ids: Tensor, parent_ids: Tensor, + *, + input_pos: Tensor | None = None, + sliding_windows: tuple[int, ...] = (), ) -> SharedPrefixAttentionState: """Build a compiled block mask for ART shared-prefix packing. @@ -65,7 +81,12 @@ def create_shared_prefix_attention_state( from art.megatron.shared_prefix_state import create_shared_prefix_state - return create_shared_prefix_state(group_ids, parent_ids) + return create_shared_prefix_state( + group_ids, + parent_ids, + input_pos=input_pos, + sliding_windows=sliding_windows, + ) class FlexDotProductAttention(torch.nn.Module): @@ -86,13 +107,8 @@ def __init__( pg_collection: ProcessGroupCollection | None = None, ): super().__init__() - del ( - layer_number, - attn_mask_type, - attention_type, - attention_dropout, - cp_comm_type, - ) + del attn_mask_type, attention_type, attention_dropout, cp_comm_type + self.layer_number = int(layer_number) self.config = config self.flex_attention = FlexAttentionWrapper() @@ -145,7 +161,9 @@ def forward( ) if isinstance(attention_bias, SharedPrefixAttentionState): - block_mask = attention_bias.block_mask + block_mask = attention_bias.block_mask_for_window( + getattr(self, "art_sliding_window", None) + ) else: assert isinstance(attention_bias, BlockMask), ( "Expected a flex BlockMask in attention_bias." 
diff --git a/src/art/megatron/flex_attn/compiled.py b/src/art/megatron/flex_attn/compiled.py index ad976754d..cf49d033c 100644 --- a/src/art/megatron/flex_attn/compiled.py +++ b/src/art/megatron/flex_attn/compiled.py @@ -1,7 +1,7 @@ """Compiled flex attention entrypoints.""" import math -from typing import Any, TypeAlias, cast +from typing import Any, Literal, TypeAlias, cast import torch from torch.nn.attention.flex_attention import ( @@ -19,15 +19,30 @@ # backend; production ART always uses FLASH here. _FORCED_FLEX_BACKEND = "FLASH" _FLASH_LSE_RESCALE = math.log(2.0) +FlexBackend: TypeAlias = Literal["FLASH", "TRITON"] SparseBlockSize: TypeAlias = int | tuple[int, int] -def normalize_flex_lse(lse: torch.Tensor) -> torch.Tensor: +def flex_backend_for_head_dims(*, head_dim: int, head_dim_v: int) -> FlexBackend: if _FORCED_FLEX_BACKEND != "FLASH": + return "TRITON" + if int(head_dim) > 256 or int(head_dim_v) > 256: + return "TRITON" + return "FLASH" + + +def normalize_flex_lse( + lse: torch.Tensor, + *, + backend: FlexBackend | None = None, +) -> torch.Tensor: + if (_FORCED_FLEX_BACKEND if backend is None else backend) != "FLASH": return lse return lse / _FLASH_LSE_RESCALE +_FLASH_FLEX_KERNEL_OPTIONS = cast(FlexKernelOptions, {"BACKEND": "FLASH"}) +_TRITON_FLEX_KERNEL_OPTIONS = cast(FlexKernelOptions, {"BACKEND": "TRITON"}) _FORCED_FLEX_KERNEL_OPTIONS = cast( FlexKernelOptions, {"BACKEND": _FORCED_FLEX_BACKEND}, @@ -49,7 +64,7 @@ def flash_sparse_block_size_for_head_dim( head_dim_v: int, device: torch.device, ) -> tuple[int, int]: - if _FORCED_FLEX_BACKEND != "FLASH": + if flex_backend_for_head_dims(head_dim=head_dim, head_dim_v=head_dim_v) != "FLASH": return (128, 128) if device.type != "cuda": return (128, 128) @@ -108,6 +123,31 @@ def _forced_flex_attention_sparse( ) +def _flex_attention_with_options(kernel_options: FlexKernelOptions) -> Any: + def _flex_attention( + q, + k, + v, + *, + block_mask, + scale, + enable_gqa, + return_aux: AuxRequest | None = None, + 
): + return flex_attention( + q, + k, + v, + block_mask=block_mask, + scale=scale, + enable_gqa=enable_gqa, + kernel_options=kernel_options, + return_aux=return_aux, + ) + + return _flex_attention + + def select_sparse_execution_family( *, is_local_stage: bool, @@ -126,15 +166,43 @@ def select_sparse_execution_family( return int(target_q_len), int(target_k_len), "sparse" -def get_sparse_compiled_flex_attention(*, family_key: str) -> Any: +def get_dense_compiled_flex_attention(*, backend: FlexBackend) -> Any: + if backend == _FORCED_FLEX_BACKEND: + return dense_compiled_flex_attention + if backend == "FLASH": + return flash_dense_compiled_flex_attention + return triton_dense_compiled_flex_attention + + +def get_sparse_compiled_flex_attention( + *, + family_key: str, + backend: FlexBackend, +) -> Any: del family_key - return sparse_compiled_flex_attention + if backend == _FORCED_FLEX_BACKEND: + return sparse_compiled_flex_attention + if backend == "FLASH": + return flash_sparse_compiled_flex_attention + return triton_sparse_compiled_flex_attention dense_compiled_flex_attention = torch.compile( _forced_flex_attention_dense, ) +flash_dense_compiled_flex_attention = torch.compile( + _flex_attention_with_options(_FLASH_FLEX_KERNEL_OPTIONS), +) +triton_dense_compiled_flex_attention = torch.compile( + _flex_attention_with_options(_TRITON_FLEX_KERNEL_OPTIONS), +) sparse_compiled_flex_attention = torch.compile( _forced_flex_attention_sparse, ) +flash_sparse_compiled_flex_attention = torch.compile( + _flex_attention_with_options(_FLASH_FLEX_KERNEL_OPTIONS), +) +triton_sparse_compiled_flex_attention = torch.compile( + _flex_attention_with_options(_TRITON_FLEX_KERNEL_OPTIONS), +) diff --git a/src/art/megatron/model_support/handlers/gemma4.py b/src/art/megatron/model_support/handlers/gemma4.py new file mode 100644 index 000000000..c26cfc212 --- /dev/null +++ b/src/art/megatron/model_support/handlers/gemma4.py @@ -0,0 +1,1777 @@ +from __future__ import annotations + +from 
contextlib import nullcontext +from copy import copy +from functools import lru_cache +import json +from pathlib import Path +import re +from types import MethodType +from typing import Any, Sequence, cast + +from megatron.core import tensor_parallel +from megatron.core.extensions.transformer_engine import ( + TERowParallelLinear, + te_checkpoint, +) +from megatron.core.fp4_utils import get_fp4_context +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.tensor_parallel.mappings import ( + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) +import torch + +from art.megatron.lora import SelfAttentionLinearProjLoRA +from art.megatron.model_support.handlers.default_dense import ( + DefaultMoeHandler, + _compile_workaround_flags_for_provider, + _require_moe_experts, +) +from art.megatron.model_support.handlers.qwen3_common import ( + _context_parallel_world_size, +) +from art.megatron.model_support.spec import ( + CompileWorkaroundConfig, + ExpertPackedLoraGroup, + ExpertPackedLoraSlot, + LayerFamilyInstance, +) + +_GEMMA4_MOE_COMPILE_WORKAROUND_FLAGS = ( + "alltoall_dtoh", + "alltoall_dispatch_preprocess", + "deepep_dispatch_combine", + "deepep_permute_restore", + "flex_token_dispatch_combine", + "te_triton_permute_with_mask_map", +) +_ART_MOE_EXPERT_KEY_RE = re.compile( + r"^(?P.*\.mlp\.experts)\.(?P\d+)\." + r"(?Pgate_up_proj|down_proj)\.(?Plora_[AB])\.weight$" +) +_VLLM_MOE_KEY_RE = re.compile( + r"^(?P.*\.moe\.experts)\." + r"(?:(?Pbase_layer)\.)?(?Plora_[AB])\.weight$" +) +_VLLM_MOE_EXPERT_KEY_RE = re.compile( + r"^(?P.*\.moe\.experts)\.(?P\d+)\." + r"(?Pgate_proj|up_proj|down_proj)\.(?Plora_[AB])\.weight$" +) +_DENSE_MLP_LORA_KEY_RE = re.compile( + r"(?P\.mlp)\.(?Pgate_proj|up_proj|down_proj)\." + r"(?Plora_[AB])\.weight$" +) +_SHARED_EXPERT_FC1_LORA_A_KEY_RE = re.compile( + r"^.*\.layers\.(?P\d+)\.mlp\.(?:shared_expert\.)?" 
+ r"(?:gate_proj|up_proj)\.lora_A\.weight$" +) +_SELF_ATTN_V_LORA_KEY_RE = re.compile( + r"^(?P.*\.layers\.(?P\d+)\.self_attn\.)v_proj\." + r"(?Plora_[AB]\.weight)$" +) +_SELF_ATTN_K_LORA_KEY_RE = re.compile( + r"^(?P.*\.layers\.(?P\d+)\.self_attn\.)k_proj\." + r"(?Plora_[AB]\.weight)$" +) +_MEGATRON_LAYER_RE = re.compile(r"(?:^|\.)layers\.(?P\d+)\.") +_HF_TEXT_EXPERT_KEY_RE = re.compile(r"(?P\.layers\.\d+)\.experts") + + +class Gemma4MoeHandler(DefaultMoeHandler): + key = "gemma4_moe" + is_moe = True + native_vllm_lora_status = "validated" + + def identity_lora_model_config(self, base_config: Any) -> Any: + return getattr(base_config, "text_config", base_config) + + def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: + attention_bias = kwargs.get("attention_bias") + from art.megatron.context_parallel.types import ArtContextParallelState + + module = model + while hasattr(module, "module"): + module = module.module + gpt_module = getattr(module, "language_model", module) + if isinstance(attention_bias, ArtContextParallelState): + setattr( + gpt_module, + "_art_gemma4_rotary_seq_len", + int(attention_bias.rank_plan.original_seq_len), + ) + else: + setattr(gpt_module, "_art_gemma4_rotary_seq_len", None) + return {"extra_block_kwargs": kwargs} + + def _identity_lora_parameter_suffixes( + self, + target_modules: list[str], + ) -> tuple[str, ...]: + suffixes = list(super()._identity_lora_parameter_suffixes(target_modules)) + target_set = set(target_modules) + if {"experts", "gate_proj", "up_proj"} & target_set: + suffixes.append("experts.gate_up_proj") + if {"experts", "down_proj"} & target_set: + suffixes.append("experts.down_proj") + return tuple(dict.fromkeys(suffixes)) + + def configure_provider_for_runtime(self, provider: Any) -> None: + _patch_gemma4_router_for_mcore() + _patch_gemma4_rotary_for_hf_proportional() + _patch_gemma4_qkv_for_hf_tied_value() + window_size = int(getattr(provider, "window_size", 1024)) + 
provider.art_flex_core_attention_wrapper = _gemma4_flex_core_attention_wrapper + provider.art_flex_sliding_windows = (window_size,) + provider.art_flex_head_dims_by_window = { + None: int(getattr(provider, "global_head_dim", provider.kv_channels)), + window_size: int(provider.kv_channels), + } + provider.moe_shared_expert_overlap = False + + def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None: + _install_gemma4_preprocess_patch(model_chunks) + _install_gemma4_full_recompute_patch(model_chunks) + + def collect_layer_families(self, provider: Any) -> list[LayerFamilyInstance]: + if int(getattr(provider, "num_moe_experts", 0) or 0) <= 0: + raise TypeError("Gemma 4 MoE handler received a dense provider") + sliding_count, global_count = _gemma4_attention_pattern(provider) + families = [ + LayerFamilyInstance(key="gemma4_sliding_attention", layer_index=0), + LayerFamilyInstance(key="grouped_moe_mlp", layer_index=0), + ] + if global_count > 0: + families.append( + LayerFamilyInstance( + key="gemma4_global_attention", + layer_index=sliding_count, + ) + ) + if int(getattr(provider, "moe_shared_expert_intermediate_size", 0) or 0) > 0: + families.append( + LayerFamilyInstance(key="shared_experts_mlp", layer_index=0) + ) + return families + + def apply_lora_adapters( + self, + model_chunks: Sequence[Any], + provider: Any, + *, + target_modules: list[str], + rank: int, + alpha: int, + ) -> None: + from megatron.core.transformer.attention import SelfAttention + from megatron.core.transformer.transformer_layer import TransformerLayer + + from art.megatron.lora import ( + _adapter_model_prefix, + _is_language_transformer_layer_name, + wrap_grouped_moe_experts_3d, + wrap_shared_experts_mlp, + wrap_standard_self_attention, + ) + + target_set = set(target_modules) + for chunk in model_chunks: + for module_name, module in chunk.named_modules(): + if not isinstance(module, TransformerLayer): + continue + if not _is_language_transformer_layer_name(module_name): + 
continue + adapter_model_prefix = _adapter_model_prefix(module) + if not isinstance(module.self_attention, SelfAttention): + raise TypeError( + "Gemma 4 expected a SelfAttention module, got " + f"{type(module.self_attention)}" + ) + attention_provider = _attention_provider_for_layer(provider, module) + qkv_targets = ( + {"q_proj", "k_proj", "v_proj"} + if not target_set + else target_set - {"o_proj"} + ) + if qkv_targets: + wrap_standard_self_attention( + module.self_attention, + adapter_model_prefix=adapter_model_prefix, + provider=attention_provider, + target_modules=qkv_targets, + rank=rank, + alpha=alpha, + ) + if ( + not target_set or {"q_proj", "k_proj", "v_proj"} & target_set + ) and _is_gemma4_global_layer(int(module.layer_number), provider): + _tie_global_value_lora_to_key(module.self_attention) + _wrap_gemma4_attention_output_lora( + module.self_attention, + adapter_model_prefix=adapter_model_prefix, + provider=attention_provider, + target_modules=target_set, + rank=rank, + alpha=alpha, + ) + wrap_grouped_moe_experts_3d( + _require_moe_experts(module), + adapter_model_prefix=adapter_model_prefix, + target_modules=target_set, + rank=rank, + alpha=alpha, + ) + shared_experts = getattr(module.mlp, "shared_experts", None) + if shared_experts is not None: + wrap_shared_experts_mlp( + shared_experts, + adapter_model_prefix=adapter_model_prefix, + provider=provider, + target_modules=target_set, + rank=rank, + alpha=alpha, + ) + + def build_adapter_weights_by_base( + self, + model_chunks: Sequence[Any], + ) -> dict[str, list[Any]]: + from megatron.core.transformer.transformer_layer import TransformerLayer + + from art.megatron.lora import _is_language_transformer_layer_name + from art.megatron.weights.adapter_export import ( + add_grouped_moe_adapter_weights, + add_shared_experts_adapter_weights, + add_standard_self_attention_adapter_weights, + layer_base_prefix, + ) + + adapter_weights_by_base: dict[str, list[Any]] = {} + for chunk in model_chunks: + for 
module_name, module in chunk.named_modules(): + if not isinstance(module, TransformerLayer): + continue + if not _is_language_transformer_layer_name(module_name): + continue + layer_prefix = layer_base_prefix(module, module_name=module_name) + add_standard_self_attention_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + self_attention=module.self_attention, + ) + add_grouped_moe_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + experts=_require_moe_experts(module), + ) + shared_experts = getattr(module.mlp, "shared_experts", None) + if shared_experts is not None: + add_shared_experts_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + shared_experts=shared_experts, + ) + return adapter_weights_by_base + + def to_vllm_lora_tensors( + self, + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], + ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + return _to_vllm_lora_tensors(tensors, adapter_config=adapter_config) + + def from_vllm_lora_tensors( + self, + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], + ) -> dict[str, torch.Tensor]: + return _from_vllm_lora_tensors(tensors, adapter_config=adapter_config) + + def expert_packed_lora_groups(self) -> tuple[ExpertPackedLoraGroup, ...]: + return ( + ExpertPackedLoraGroup( + art_group_suffix=".mlp.experts", + slots=( + ExpertPackedLoraSlot( + source_projection="gate_up_proj", + source_lora="lora_A", + output_suffix="base_layer.lora_A.weight", + pack_layout="expert_rows", + ), + ExpertPackedLoraSlot( + source_projection="gate_up_proj", + source_lora="lora_B", + output_suffix="base_layer.lora_B.weight", + pack_layout="rank_major_expert_cols", + ), + ExpertPackedLoraSlot( + source_projection="down_proj", + source_lora="lora_A", + output_suffix="lora_A.weight", + pack_layout="expert_rows", + ), + ExpertPackedLoraSlot( + source_projection="down_proj", + source_lora="lora_B", + output_suffix="lora_B.weight", 
+ pack_layout="rank_major_expert_cols", + ), + ), + ), + ) + + def compile_workaround_config( + self, + provider: Any, + ) -> CompileWorkaroundConfig: + if bool(getattr(provider, "moe_shared_expert_overlap", False)): + return CompileWorkaroundConfig( + shared_expert_state="shared_expert_overlap", + disable_compile=True, + ) + return CompileWorkaroundConfig( + flags=_compile_workaround_flags_for_provider( + provider, + _GEMMA4_MOE_COMPILE_WORKAROUND_FLAGS, + ), + shared_expert_state="shared_experts", + disable_compile=False, + ) + + +GEMMA4_MOE_HANDLER = Gemma4MoeHandler() + +_GEMMA4_ROUTER_PATCHED = False +_GEMMA4_ROTARY_PATCHED = False +_GEMMA4_QKV_PATCHED = False + + +def _patch_gemma4_router_for_mcore() -> None: + global _GEMMA4_ROUTER_PATCHED + if _GEMMA4_ROUTER_PATCHED: + return + from megatron.bridge.models.gemma import gemma4_provider + from megatron.core.transformer.moe.router import TopKRouter + + def _art_gemma4_router_routing( + self: Any, + logits: torch.Tensor, + padding_mask: torch.Tensor | None = None, + input_ids: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + del input_ids + routing_probs, routing_map = TopKRouter.routing( + self, + logits, + padding_mask=padding_mask, + ) + if routing_map is not None: + prob_sums = routing_probs.sum(dim=-1, keepdim=True).clamp(min=1e-20) + routing_probs = routing_probs / prob_sums + routing_probs = routing_probs * self.per_expert_scale.unsqueeze(0) + return routing_probs, routing_map + + setattr(gemma4_provider.Gemma4TopKRouter, "routing", _art_gemma4_router_routing) + _GEMMA4_ROUTER_PATCHED = True + + +def _gemma4_hf_proportional_inv_freq( + *, + global_kv_channels: int, + global_rotary_percent: float, + rotary_base: int, + device: torch.device, +) -> torch.Tensor: + """HF proportional RoPE pads non-rotary pairs with zero-frequency angles.""" + rope_angles = int(global_rotary_percent * global_kv_channels // 2) + inv_freq_rotated = 1.0 / ( + rotary_base + ** ( + torch.arange(0, 2 * 
rope_angles, 2, dtype=torch.float32, device=device) + / global_kv_channels + ) + ) + nope_angles = global_kv_channels // 2 - rope_angles + if nope_angles <= 0: + return inv_freq_rotated + return torch.cat( + ( + inv_freq_rotated, + torch.zeros(nope_angles, dtype=torch.float32, device=device), + ), + dim=0, + ) + + +def _patch_gemma4_rotary_for_hf_proportional() -> None: + global _GEMMA4_ROTARY_PATCHED + if _GEMMA4_ROTARY_PATCHED: + return + from megatron.bridge.models.gemma import gemma4_provider + + original_init = cast(Any, gemma4_provider.Gemma4RotaryEmbedding.__init__) + + def _art_gemma4_rotary_init( + self: Any, + *, + kv_channels: int, + rotary_percent: float, + rotary_interleaved: bool = False, + seq_len_interpolation_factor: float | None = None, + rotary_base: int = 1_000_000, + rope_scaling: bool = False, + use_cpu_initialization: bool = False, + rotary_base_local: int = 10_000, + global_kv_channels: int = 512, + global_rotary_percent: float = 0.25, + ) -> None: + original_init( + self, + kv_channels=kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + rope_scaling=rope_scaling, + use_cpu_initialization=use_cpu_initialization, + rotary_base_local=rotary_base_local, + global_kv_channels=global_kv_channels, + global_rotary_percent=global_rotary_percent, + ) + self.inv_freq = _gemma4_hf_proportional_inv_freq( + global_kv_channels=global_kv_channels, + global_rotary_percent=global_rotary_percent, + rotary_base=rotary_base, + device=self.inv_freq.device, + ) + + setattr(gemma4_provider.Gemma4RotaryEmbedding, "__init__", _art_gemma4_rotary_init) + _GEMMA4_ROTARY_PATCHED = True + + +def _patch_gemma4_qkv_for_hf_tied_value() -> None: + global _GEMMA4_QKV_PATCHED + if _GEMMA4_QKV_PATCHED: + return + from megatron.bridge.models.gemma import gemma4_provider + from megatron.core.transformer.attention import SelfAttention + + def 
_art_gemma4_get_query_key_value_tensors( + self: Any, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor | None = None, + **kwargs: Any, + ) -> tuple[Any, ...]: + result = cast( + tuple[Any, ...], + SelfAttention.get_query_key_value_tensors( + self, + hidden_states, + key_value_states, + **kwargs, + ), + ) + if len(result) < 3: + return result + query, key, value = result[0], result[1], result[2] + # HF global K=V uses the raw K projection for V before k_norm; the + # synthesized V rows are loaded from K, so V-norm should consume them here. + v_float = value.float() + rms = v_float.pow(2).mean(-1, keepdim=True).add(self._v_norm_eps).sqrt() + value = (v_float / rms).to(value.dtype) + return (query, key, value) + result[3:] + + setattr( + gemma4_provider.Gemma4SelfAttention, + "get_query_key_value_tensors", + _art_gemma4_get_query_key_value_tensors, + ) + _GEMMA4_QKV_PATCHED = True + + +def _gather_absolute_rotary_pos_emb( + table_source: torch.Tensor, + *, + position_ids: torch.Tensor, +) -> torch.Tensor: + embedding_dim = int(table_source.shape[-1]) + batch_size, sequence_length = position_ids.shape + gathered = table_source.view(table_source.shape[0], embedding_dim).index_select( + 0, + position_ids.reshape(-1), + ) + return ( + gathered.view(batch_size, sequence_length, embedding_dim) + .permute(1, 0, 2) + .contiguous() + .unsqueeze(2) + ) + + +def _install_gemma4_preprocess_patch(model_chunks: Sequence[Any]) -> None: + from megatron.core.models.gpt.gpt_model import GPTModel + + for chunk in model_chunks: + module: Any = chunk + while hasattr(module, "module"): + module = module.module + gpt_module = ( + module + if isinstance(module, GPTModel) + else cast(GPTModel, getattr(module, "language_model")) + ) + preprocess = gpt_module._preprocess + + def preprocess_hook( + *args: Any, + _gpt_module: Any = gpt_module, + _preprocess: Any = preprocess, + **kwargs: Any, + ) -> tuple[Any, ...]: + position_ids = kwargs.get("position_ids") + gemma4_rotary = 
getattr(_gpt_module, "rotary_pos_emb") + local_rotary = getattr(gemma4_rotary, "rope_local", None) + rotary_cp_group = getattr(gemma4_rotary, "cp_group", None) + local_rotary_cp_group = getattr(local_rotary, "cp_group", None) + uses_dispatched_local_cp_positions = ( + isinstance(position_ids, torch.Tensor) + and position_ids.ndim == 2 + and _context_parallel_world_size(getattr(_gpt_module, "config", None)) + > 1 + and (rotary_cp_group is not None or local_rotary_cp_group is not None) + ) + if uses_dispatched_local_cp_positions: + setattr(gemma4_rotary, "cp_group", None) + if local_rotary is not None: + setattr(local_rotary, "cp_group", None) + rotary_seq_len = getattr( + _gpt_module, "_art_gemma4_rotary_seq_len", None + ) + if rotary_seq_len is not None: + from megatron.core.packed_seq_params import PackedSeqParams + + kwargs = dict(kwargs) + kwargs["packed_seq_params"] = PackedSeqParams( + max_seqlen_q=int(rotary_seq_len), + max_seqlen_kv=int(rotary_seq_len), + ) + try: + preproc_output = list(_preprocess(*args, **kwargs)) + finally: + if uses_dispatched_local_cp_positions: + setattr(gemma4_rotary, "cp_group", rotary_cp_group) + if local_rotary is not None: + setattr(local_rotary, "cp_group", local_rotary_cp_group) + decoder_input = cast(torch.Tensor, preproc_output[0]) + if not decoder_input.requires_grad and decoder_input.is_leaf: + decoder_input.requires_grad_(True) + rotary_pos_emb = preproc_output[1] + if not isinstance(position_ids, torch.Tensor) or not isinstance( + rotary_pos_emb, + (tuple, list), + ): + return tuple(preproc_output) + local_table, global_table = rotary_pos_emb + if not torch.is_tensor(local_table) or not torch.is_tensor(global_table): + return tuple(preproc_output) + preproc_output[1] = ( + _gather_absolute_rotary_pos_emb( + local_table, + position_ids=position_ids, + ), + _gather_absolute_rotary_pos_emb( + global_table, + position_ids=position_ids, + ), + ) + return tuple(preproc_output) + + gpt_module._preprocess = preprocess_hook # 
type: ignore[attr-defined] + + +def _install_gemma4_full_recompute_patch(model_chunks: Sequence[Any]) -> None: + for chunk in model_chunks: + module: Any = chunk + while hasattr(module, "module"): + module = module.module + gpt_module = getattr(module, "language_model", module) + decoder = getattr(gpt_module, "decoder", None) + if decoder is None or getattr( + decoder, "_art_gemma4_full_recompute_patch", False + ): + continue + original_checkpointed_forward = decoder._checkpointed_forward + + def checkpointed_forward( + self: Any, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + context: torch.Tensor, + context_mask: torch.Tensor, + rotary_pos_emb: Any, + attention_bias: Any, + packed_seq_params: Any, + use_inner_quantization_context: bool, + padding_mask: torch.Tensor | None = None, + extract_layer_indices: set[int] | None = None, + layer_offset: int = 0, + *, + _original_checkpointed_forward: Any = original_checkpointed_forward, + ) -> Any: + if not isinstance(rotary_pos_emb, (tuple, list)): + return _original_checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + use_inner_quantization_context=use_inner_quantization_context, + padding_mask=padding_mask, + extract_layer_indices=extract_layer_indices, + layer_offset=layer_offset, + ) + rotary_pos_emb_local, rotary_pos_emb_global = rotary_pos_emb + if extract_layer_indices is None: + extract_layer_indices = set() + intermediate_hidden_states: list[torch.Tensor] = [] + + def custom(start: int, end: int) -> Any: + def custom_forward( + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + context: torch.Tensor, + context_mask: torch.Tensor, + rotary_pos_emb_local: torch.Tensor, + rotary_pos_emb_global: torch.Tensor, + padding_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + 
rotary_pair = (rotary_pos_emb_local, rotary_pos_emb_global) + for index in range(start, end): + layer = self._get_layer(index) + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context( + self.config, layer.layer_number - 1 + ) + elif self.config.fp4: + inner_quantization_context = get_fp4_context( + self.config, layer.layer_number - 1 + ) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + with inner_quantization_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pair, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func: Any) -> Any: + args = ( + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb_local, + rotary_pos_emb_global, + padding_mask, + ) + if self.config.fp8 or self.config.fp4: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + self.pg_collection.tp, + *args, + ) + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + *args, + ) + + if self.config.recompute_method == "uniform": + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + chunk_end = min( + layer_idx + self.config.recompute_num_layers, + self.num_layers_per_pipeline_rank, + ) + hidden_states, context = checkpoint_handler( + custom(layer_idx, chunk_end) + ) + for idx in range(layer_idx, chunk_end): + if (idx + layer_offset) in extract_layer_indices: + if idx == chunk_end - 1: + intermediate_hidden_states.append(hidden_states) + layer_idx += self.config.recompute_num_layers + elif self.config.recompute_method == "block": + 
recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + if ( + self.config.fp8 or self.config.fp4 + ) and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx + < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states, context = checkpoint_handler( + custom(layer_idx, layer_idx + 1) + ) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)( + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb_local, + rotary_pos_emb_global, + padding_mask, + ) + if (layer_idx + layer_offset) in extract_layer_indices: + intermediate_hidden_states.append(hidden_states) + else: + raise ValueError("Invalid activation recompute method.") + if len(extract_layer_indices) > 0: + return hidden_states, intermediate_hidden_states + return hidden_states + + decoder._checkpointed_forward = MethodType(checkpointed_forward, decoder) + decoder._art_gemma4_full_recompute_patch = True + + +def _gemma4_attention_pattern(provider: Any) -> tuple[int, int]: + pattern = getattr(provider, "interleaved_attn_pattern", (0, 1)) + if not pattern: + return (0, 1) + if len(pattern) == 1: + return (int(pattern[0]), 0) + return (int(pattern[0]), int(pattern[1])) + + +def _is_gemma4_global_layer(layer_number: int, provider: Any) -> bool: + layer_types = getattr(provider, "art_gemma4_layer_types", None) + if layer_types is not None: + return layer_types[int(layer_number) - 1] == "full_attention" + sliding_count, global_count = _gemma4_attention_pattern(provider) + if global_count <= 0: + return False + cycle = sliding_count + global_count + if cycle <= 0: + return False + return (layer_number - 1) % cycle >= sliding_count + + +def _gemma4_sliding_window_for_layer(provider: Any, layer_number: int) -> int | None: + if _is_gemma4_global_layer(int(layer_number), provider): + return None + return int(provider.window_size) + + +def 
_gemma4_flex_core_attention_wrapper( + provider: Any, base_cls: type[Any] +) -> type[Any]: + class Gemma4ArtFlexCoreAttention(base_cls): # type: ignore[misc, valid-type] + def __init__( + self, + config: Any, + layer_number: int, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(config, layer_number, *args, **kwargs) + self.art_sliding_window = _gemma4_sliding_window_for_layer( + provider, + layer_number, + ) + + return Gemma4ArtFlexCoreAttention + + +def _attention_provider_for_layer(provider: Any, module: Any) -> Any: + if not _is_gemma4_global_layer(int(module.layer_number), provider): + return provider + global_provider = copy(provider) + global_provider.kv_channels = getattr(provider, "global_head_dim") + global_provider.num_query_groups = getattr(provider, "num_global_key_value_heads") + return global_provider + + +def _tie_global_value_lora_to_key(self_attention: Any) -> None: + linear_qkv = self_attention.linear_qkv + linear_qkv.v_proj_lora = linear_qkv.k_proj_lora + + +class _Gemma4SelfAttentionLinearProjLoRA(SelfAttentionLinearProjLoRA): + def __init__( + self, + *, + adapter_model_prefix: str, + linear_proj: TERowParallelLinear, + rank: int, + alpha: int, + provider: Any, + ) -> None: + super().__init__( + adapter_model_prefix=adapter_model_prefix, + linear_proj=linear_proj, + rank=rank, + alpha=alpha, + provider=provider, + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + linear_proj = self.linear_proj + base_output, bias_output = TERowParallelLinear.forward(linear_proj, x) + lora_output = self.lora(x) + if self.reduce_output and self.provider.tensor_model_parallel_size > 1: + if self.provider.sequence_parallel: + lora_output = reduce_scatter_to_sequence_parallel_region(lora_output) + else: + lora_output = reduce_from_tensor_model_parallel_region(lora_output) + output = base_output + lora_output + post_layernorm = getattr(linear_proj, "post_layernorm", None) + if post_layernorm is not None: + output = 
post_layernorm(output) + if isinstance(output, tuple): + output = output[0] + return output, bias_output + + +def _wrap_gemma4_attention_output_lora( + self_attention: Any, + *, + adapter_model_prefix: str, + provider: Any, + target_modules: set[str], + rank: int, + alpha: int, +) -> None: + from art.megatron.lora import _targets_include, _unwrap_attr + + if not _targets_include(target_modules, "o_proj"): + return + linear_proj = _unwrap_attr( + self_attention.linear_proj, + "linear_proj", + TERowParallelLinear, + ) + self_attention.linear_proj = _Gemma4SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn.o_proj", + linear_proj=linear_proj, + rank=rank, + alpha=alpha, + provider=provider, + ) + + +def _to_vllm_key(key: str) -> str: + key = key.replace(".mlp.shared_expert.", ".mlp.").replace( + ".mlp.experts", + ".moe.experts", + ) + return _HF_TEXT_EXPERT_KEY_RE.sub(r"\g.moe.experts", key) + + +def _from_vllm_key(key: str) -> str: + key = key.replace(".moe.experts", ".mlp.experts") + return _DENSE_MLP_LORA_KEY_RE.sub( + r"\g.shared_expert.\g.\g.weight", + key, + ) + + +def _pack_vllm_3d_lora_b(blocks: list[torch.Tensor]) -> torch.Tensor: + stacked = torch.stack(blocks, dim=0) + return stacked.permute(1, 2, 0).reshape(stacked.shape[1], -1).contiguous() + + +def _unpack_vllm_3d_lora_b( + tensor: torch.Tensor, + *, + num_experts: int, + rank: int, +) -> torch.Tensor: + return tensor.reshape(tensor.shape[0], rank, num_experts).permute(2, 0, 1) + + +def _clone(tensor: torch.Tensor) -> torch.Tensor: + return tensor.clone().contiguous() + + +@lru_cache(maxsize=8) +def _gemma4_text_config_dict(base_model_name_or_path: str) -> dict[str, Any]: + config_path = Path(base_model_name_or_path) / "config.json" + if not config_path.exists(): + from huggingface_hub import hf_hub_download + + config_path = Path( + hf_hub_download( + base_model_name_or_path, + "config.json", + local_files_only=True, + ) + ) + config = 
json.loads(config_path.read_text(encoding="utf-8")) + return dict(config.get("text_config") or config) + + +def _gemma4_k_eq_v_layers(adapter_config: dict[str, Any]) -> set[int]: + base_model = str(adapter_config["base_model_name_or_path"]) + config = _gemma4_text_config_dict(base_model) + if not bool(config.get("attention_k_eq_v", False)): + return set() + return { + layer_idx + for layer_idx, layer_type in enumerate(config["layer_types"]) + if layer_type == "full_attention" + } + + +def _gemma4_hf_file(base_model_name_or_path: str, filename: str) -> Path: + base_path = Path(base_model_name_or_path) + if base_path.exists(): + return base_path / filename + from huggingface_hub import hf_hub_download + + return Path( + hf_hub_download( + base_model_name_or_path, + filename, + local_files_only=True, + ) + ) + + +@lru_cache(maxsize=8) +def _gemma4_shared_expert_prenorm_corrections( + base_model_name_or_path: str, +) -> tuple[torch.Tensor, ...]: + from safetensors import safe_open + + index = json.loads( + _gemma4_hf_file( + base_model_name_or_path, + "model.safetensors.index.json", + ).read_text(encoding="utf-8") + ) + weight_map = dict(index["weight_map"]) + text_config = _gemma4_text_config_dict(base_model_name_or_path) + num_layers = int(text_config["num_hidden_layers"]) + norm_keys_by_file: dict[str, list[tuple[int, str, str]]] = {} + + for layer in range(num_layers): + for suffix in ( + "pre_feedforward_layernorm", + "pre_feedforward_layernorm_2", + ): + candidates = ( + f"model.language_model.layers.{layer}.{suffix}.weight", + f"model.layers.{layer}.{suffix}.weight", + ) + key = next(candidate for candidate in candidates if candidate in weight_map) + norm_keys_by_file.setdefault(weight_map[key], []).append( + (layer, suffix, key) + ) + norm_weights: dict[tuple[int, str], torch.Tensor] = {} + for filename, entries in norm_keys_by_file.items(): + with safe_open( + _gemma4_hf_file(base_model_name_or_path, filename), + framework="pt", + device="cpu", + ) as handle: 
+ for layer, suffix, key in entries: + norm_weights[(layer, suffix)] = handle.get_tensor(key).float() + + return tuple( + norm_weights[(layer, "pre_feedforward_layernorm")] + / norm_weights[(layer, "pre_feedforward_layernorm_2")] + for layer in range(num_layers) + ) + + +def _shared_expert_fc1_prenorm_correction( + *, + adapter_config: dict[str, Any], + layer: int, + device: torch.device, +) -> torch.Tensor: + # Megatron Bridge folds pffl/pffl2 into shared-expert FC1 base weights because + # MCore feeds pffl2-normalized activations while HF/vLLM feeds pffl-normalized + # activations. LoRA-A needs the same basis change at the HF/vLLM boundary. + return _gemma4_shared_expert_prenorm_corrections( + str(adapter_config["base_model_name_or_path"]) + )[layer].to(device=device) + + +def _rescale_shared_expert_fc1_lora_a( + key: str, + tensor: torch.Tensor, + *, + adapter_config: dict[str, Any], + to_vllm: bool, +) -> torch.Tensor: + match = _SHARED_EXPERT_FC1_LORA_A_KEY_RE.match(key) + if match is None: + return tensor + correction = _shared_expert_fc1_prenorm_correction( + adapter_config=adapter_config, + layer=int(match.group("layer")), + device=tensor.device, + ) + factor = correction.reciprocal() if to_vllm else correction + return (tensor.float() * factor.unsqueeze(0)).to(tensor.dtype).contiguous() + + +def _drop_gemma4_k_eq_v_v_lora_tensors( + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], +) -> dict[str, torch.Tensor]: + k_eq_v_layers = _gemma4_k_eq_v_layers(adapter_config) + if not k_eq_v_layers: + return tensors + return { + key: tensor + for key, tensor in tensors.items() + if not ( + (match := _SELF_ATTN_V_LORA_KEY_RE.match(key)) is not None + and int(match.group("layer")) in k_eq_v_layers + ) + } + + +def _add_gemma4_k_eq_v_v_lora_tensors( + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], +) -> dict[str, torch.Tensor]: + k_eq_v_layers = _gemma4_k_eq_v_layers(adapter_config) + if not k_eq_v_layers: + return 
tensors + transformed = dict(tensors) + for key, tensor in tensors.items(): + match = _SELF_ATTN_K_LORA_KEY_RE.match(key) + if match is None or int(match.group("layer")) not in k_eq_v_layers: + continue + v_key = f"{match.group('prefix')}v_proj.{match.group('suffix')}" + if v_key not in transformed: + transformed[v_key] = tensor.clone().contiguous() + return transformed + + +def _vllm_moe_config(adapter_config: dict[str, Any]) -> dict[str, Any]: + config = dict(adapter_config) + target_modules = list(config.get("target_modules") or []) + if "experts" not in target_modules: + target_modules.append("experts") + config["target_modules"] = target_modules + return config + + +def _group_art_moe_tensors( + tensors: dict[str, torch.Tensor], +) -> dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]]: + grouped: dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]] = {} + for key, tensor in tensors.items(): + match = _ART_MOE_EXPERT_KEY_RE.match(key) + if match is None: + continue + grouped.setdefault(match.group("prefix"), {}).setdefault( + int(match.group("expert")), + {}, + ).setdefault(match.group("module"), {})[match.group("lora")] = tensor + return grouped + + +def _to_vllm_lora_tensors( + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], +) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + grouped = _group_art_moe_tensors(tensors) + if not grouped: + transformed = { + vllm_key: _rescale_shared_expert_fc1_lora_a( + vllm_key, + tensor, + adapter_config=adapter_config, + to_vllm=True, + ) + for key, tensor in tensors.items() + for vllm_key in (_to_vllm_key(key),) + } + if len(transformed) != len(tensors): + raise RuntimeError("Duplicate Gemma 4 LoRA tensor after vLLM conversion") + transformed = _add_gemma4_k_eq_v_v_lora_tensors( + transformed, + adapter_config=adapter_config, + ) + has_fused_experts = any(_VLLM_MOE_KEY_RE.match(key) for key in transformed) + return ( + transformed, + _vllm_moe_config(adapter_config) if has_fused_experts 
else adapter_config, + ) + + transformed: dict[str, torch.Tensor] = {} + used_keys: set[str] = set() + for prefix, experts in grouped.items(): + vllm_prefix = _to_vllm_key(prefix) + gate_up_a: list[torch.Tensor] = [] + gate_up_b: list[torch.Tensor] = [] + down_a: list[torch.Tensor] = [] + down_b: list[torch.Tensor] = [] + for expert in sorted(experts): + modules = experts[expert] + try: + gate_up_a_tensor = modules["gate_up_proj"]["lora_A"] + gate_up_b_tensor = modules["gate_up_proj"]["lora_B"] + down_a_tensor = modules["down_proj"]["lora_A"] + down_b_tensor = modules["down_proj"]["lora_B"] + except KeyError as exc: + raise RuntimeError( + f"Incomplete Gemma 4 MoE LoRA block for {prefix}.{expert}" + ) from exc + gate_up_a.append(gate_up_a_tensor.contiguous()) + gate_up_b.append(gate_up_b_tensor.contiguous()) + down_a.append(down_a_tensor.contiguous()) + down_b.append(down_b_tensor.contiguous()) + for module_name in ("gate_up_proj", "down_proj"): + for lora_name in ("lora_A", "lora_B"): + used_keys.add(f"{prefix}.{expert}.{module_name}.{lora_name}.weight") + transformed[f"{vllm_prefix}.base_layer.lora_A.weight"] = torch.cat( + gate_up_a, + dim=0, + ).contiguous() + transformed[f"{vllm_prefix}.base_layer.lora_B.weight"] = _pack_vllm_3d_lora_b( + gate_up_b + ) + transformed[f"{vllm_prefix}.lora_A.weight"] = torch.cat( + down_a, + dim=0, + ).contiguous() + transformed[f"{vllm_prefix}.lora_B.weight"] = _pack_vllm_3d_lora_b(down_b) + + for key, tensor in tensors.items(): + if key in used_keys: + continue + vllm_key = _to_vllm_key(key) + if vllm_key in transformed: + raise RuntimeError( + f"Duplicate Gemma 4 LoRA tensor after conversion: {vllm_key}" + ) + transformed[vllm_key] = _rescale_shared_expert_fc1_lora_a( + vllm_key, + tensor, + adapter_config=adapter_config, + to_vllm=True, + ) + transformed = _add_gemma4_k_eq_v_v_lora_tensors( + transformed, + adapter_config=adapter_config, + ) + return transformed, _vllm_moe_config(adapter_config) + + +def 
_from_vllm_lora_tensors( + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], +) -> dict[str, torch.Tensor]: + expert_grouped: dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]] = {} + for key, tensor in tensors.items(): + match = _VLLM_MOE_EXPERT_KEY_RE.match(key) + if match is None: + continue + expert_grouped.setdefault(match.group("prefix"), {}).setdefault( + int(match.group("expert")), + {}, + ).setdefault(match.group("module"), {})[match.group("lora")] = tensor + if expert_grouped: + return _drop_gemma4_k_eq_v_v_lora_tensors( + _from_vllm_per_expert_lora_tensors( + tensors, + expert_grouped=expert_grouped, + adapter_config=adapter_config, + ), + adapter_config=adapter_config, + ) + + grouped: dict[str, dict[str, torch.Tensor]] = {} + for key, tensor in tensors.items(): + match = _VLLM_MOE_KEY_RE.match(key) + if match is None: + continue + slot = ( + f"{'base_layer.' if match.group('base_layer') else ''}{match.group('lora')}" + ) + grouped.setdefault(match.group("prefix"), {})[slot] = tensor + if not grouped: + return _drop_gemma4_k_eq_v_v_lora_tensors( + { + art_key: _rescale_shared_expert_fc1_lora_a( + art_key, + tensor, + adapter_config=adapter_config, + to_vllm=False, + ) + for key, tensor in tensors.items() + for art_key in (_from_vllm_key(key),) + }, + adapter_config=adapter_config, + ) + + rank = int(adapter_config["r"]) + transformed: dict[str, torch.Tensor] = {} + used_keys: set[str] = set() + for prefix, slots in grouped.items(): + try: + gate_up_a = slots["base_layer.lora_A"] + gate_up_b = slots["base_layer.lora_B"] + down_a = slots["lora_A"] + down_b = slots["lora_B"] + except KeyError as exc: + raise RuntimeError( + f"Incomplete Gemma 4 vLLM MoE LoRA block for {prefix}" + ) from exc + if gate_up_a.shape[0] % rank != 0: + raise RuntimeError( + f"{prefix}: gate/up lora_A shape {tuple(gate_up_a.shape)} " + f"is not divisible by rank {rank}" + ) + num_experts = gate_up_a.shape[0] // rank + art_prefix = 
_from_vllm_key(prefix) + gate_up_b_by_expert = _unpack_vllm_3d_lora_b( + gate_up_b, + num_experts=num_experts, + rank=rank, + ) + down_b_by_expert = _unpack_vllm_3d_lora_b( + down_b, + num_experts=num_experts, + rank=rank, + ) + for expert in range(num_experts): + row = expert * rank + transformed[f"{art_prefix}.{expert}.gate_up_proj.lora_A.weight"] = ( + gate_up_a[row : row + rank].contiguous() + ) + transformed[f"{art_prefix}.{expert}.gate_up_proj.lora_B.weight"] = ( + gate_up_b_by_expert[expert].contiguous() + ) + transformed[f"{art_prefix}.{expert}.down_proj.lora_A.weight"] = down_a[ + row : row + rank + ].contiguous() + transformed[f"{art_prefix}.{expert}.down_proj.lora_B.weight"] = ( + down_b_by_expert[expert].contiguous() + ) + used_keys.update( + { + f"{prefix}.base_layer.lora_A.weight", + f"{prefix}.base_layer.lora_B.weight", + f"{prefix}.lora_A.weight", + f"{prefix}.lora_B.weight", + } + ) + for key, tensor in tensors.items(): + if key in used_keys: + continue + art_key = _from_vllm_key(key) + if art_key in transformed: + raise RuntimeError( + f"Duplicate Gemma 4 LoRA tensor after conversion: {art_key}" + ) + transformed[art_key] = _rescale_shared_expert_fc1_lora_a( + art_key, + tensor, + adapter_config=adapter_config, + to_vllm=False, + ) + return _drop_gemma4_k_eq_v_v_lora_tensors( + transformed, + adapter_config=adapter_config, + ) + + +def _from_vllm_per_expert_lora_tensors( + tensors: dict[str, torch.Tensor], + *, + expert_grouped: dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]], + adapter_config: dict[str, Any], +) -> dict[str, torch.Tensor]: + transformed: dict[str, torch.Tensor] = {} + used_keys: set[str] = set() + for prefix, experts in expert_grouped.items(): + art_prefix = _from_vllm_key(prefix) + for expert, modules in experts.items(): + try: + gate_a = modules["gate_proj"]["lora_A"] + gate_b = modules["gate_proj"]["lora_B"] + up_a = modules["up_proj"]["lora_A"] + up_b = modules["up_proj"]["lora_B"] + down_a = 
modules["down_proj"]["lora_A"] + down_b = modules["down_proj"]["lora_B"] + except KeyError as exc: + raise RuntimeError( + f"Incomplete Gemma 4 vLLM MoE LoRA block for {prefix}.{expert}" + ) from exc + if not torch.equal(gate_a, up_a): + raise RuntimeError( + "Gemma 4 Megatron gate_up_proj requires gate/up LoRA-A " + f"tensors to match for {prefix}.{expert}" + ) + transformed[f"{art_prefix}.{expert}.gate_up_proj.lora_A.weight"] = _clone( + gate_a + ) + transformed[f"{art_prefix}.{expert}.gate_up_proj.lora_B.weight"] = ( + torch.cat([gate_b, up_b], dim=0).contiguous() + ) + transformed[f"{art_prefix}.{expert}.down_proj.lora_A.weight"] = _clone( + down_a + ) + transformed[f"{art_prefix}.{expert}.down_proj.lora_B.weight"] = _clone( + down_b + ) + for module_name in ("gate_proj", "up_proj", "down_proj"): + for lora_name in ("lora_A", "lora_B"): + used_keys.add(f"{prefix}.{expert}.{module_name}.{lora_name}.weight") + for key, tensor in tensors.items(): + if key in used_keys: + continue + if _VLLM_MOE_KEY_RE.match(key) is not None: + raise RuntimeError( + "Mixed fused and per-expert Gemma 4 vLLM MoE LoRA tensors" + ) + art_key = _from_vllm_key(key) + transformed[art_key] = _rescale_shared_expert_fc1_lora_a( + art_key, + tensor, + adapter_config=adapter_config, + to_vllm=False, + ) + return transformed + + +def _gemma4_text_only_mapping_registry(hf_config: Any | None = None) -> Any: + from megatron.bridge.models.conversion.mapping_registry import ( + MegatronMappingRegistry, + ) + from megatron.bridge.models.gemma.gemma4_bridge import _Gemma4QKVMapping + from megatron.bridge.models.gemma_vl.gemma4_vl_bridge import Gemma4VLBridge + + upstream_registry = Gemma4VLBridge().mapping_registry() + global_layer_indices = _gemma4_global_layer_indices(hf_config) + ( + bridge_gate_up_mapping, + bridge_down_mapping, + art_gate_up_mapping, + art_down_mapping, + ) = _art_gemma4_expert_mapping_types() + + class _ArtGemma4TextOnlyQKVMapping(_Gemma4QKVMapping): + def __init__( + self, + 
megatron_param: str, + q: str, + k: str, + v: str, + *, + global_layer_indices: tuple[int, ...], + ) -> None: + super().__init__(megatron_param, q, k, v) + self._global_layer_indices = global_layer_indices + self._export_hf_param = dict(cast(dict[str, str], self.hf_param)) + + def resolve(self, captures: tuple[str, ...]) -> Any: + megatron_param, hf_param = self._resolve_names(captures) + hf_param = cast(dict[str, str], hf_param) + resolved = type(self)( + megatron_param, + hf_param["q"], + hf_param["k"], + hf_param["v"], + global_layer_indices=self._global_layer_indices, + ) + layer_index = _megatron_layer_index(megatron_param) + if layer_index in self._global_layer_indices: + resolved_hf_param = dict(cast(dict[str, str], resolved.hf_param)) + resolved_hf_param["v"] = resolved_hf_param["k"] + resolved.hf_param = resolved_hf_param + return resolved + + def megatron_to_hf( + self, + megatron_weights: torch.Tensor | None, + megatron_module: Any | None, + ) -> dict[str, torch.Tensor]: + import_hf_param = self.hf_param + self.hf_param = self._export_hf_param + try: + return super().megatron_to_hf(megatron_weights, megatron_module) + finally: + self.hf_param = import_hf_param + + language_mappings = [ + _text_only_gemma4_mapping( + mapping, + qkv_mapping_type=_ArtGemma4TextOnlyQKVMapping, + bridge_gate_up_mapping=bridge_gate_up_mapping, + bridge_down_mapping=bridge_down_mapping, + art_gate_up_mapping=art_gate_up_mapping, + art_down_mapping=art_down_mapping, + global_layer_indices=global_layer_indices, + ) + for mapping in upstream_registry.mappings + if mapping.megatron_param.startswith("language_model.") + ] + return MegatronMappingRegistry(*language_mappings) + + +def _text_only_gemma4_mapping( + mapping: Any, + *, + qkv_mapping_type: type[Any], + bridge_gate_up_mapping: type[Any], + bridge_down_mapping: type[Any], + art_gate_up_mapping: type[Any], + art_down_mapping: type[Any], + global_layer_indices: tuple[int, ...], +) -> Any: + megatron_param = 
mapping.megatron_param.removeprefix("language_model.") + hf_param = getattr(mapping, "hf_param", None) + if isinstance(mapping, bridge_gate_up_mapping): + return art_gate_up_mapping(megatron_param, hf_param) + if isinstance(mapping, bridge_down_mapping): + return art_down_mapping(megatron_param, hf_param) + if ( + megatron_param.endswith(".self_attention.linear_qkv.weight") + and isinstance(hf_param, dict) + and set(hf_param) == {"q", "k", "v"} + ): + return qkv_mapping_type( + megatron_param, + hf_param["q"], + hf_param["k"], + hf_param["v"], + global_layer_indices=global_layer_indices, + ) + cloned = copy(mapping) + cloned.megatron_param = megatron_param + return cloned + + +def _art_gemma4_expert_mapping_types() -> tuple[ + type[Any], type[Any], type[Any], type[Any] +]: + from megatron.bridge.models.conversion.param_mapping import ( + ColumnParallelMapping, + FusedExpertMapping, + FusedGatedExpertMapping, + RowParallelMapping, + _align_expert_weight_to_shape, + ) + from megatron.bridge.models.conversion.utils import ( + get_module_and_param_from_name, + ) + from megatron.bridge.utils.common_utils import extract_expert_number_from_param + + class _ArtGemma4ExpertGateUpProjMapping(FusedGatedExpertMapping): + def hf_to_megatron( + self, + hf_weights: Any, + megatron_module: Any, + ) -> torch.Tensor: + global_expert_number = extract_expert_number_from_param(self.megatron_param) + expert_weight = _select_gemma4_expert_weight( + hf_weights, + global_expert_number=global_expert_number, + ep_size=int(self.ep_size), + ) + normalized_param = self._normalize_expert_param_name(self.megatron_param) + target_param = get_module_and_param_from_name( + megatron_module, normalized_param + )[1] + full_target_shape = ( + target_param.shape[0] * self.tp_size, + target_param.shape[1], + ) + gate_target_shape = ( + full_target_shape[0] // 2, + full_target_shape[1], + ) + if full_target_shape[0] % 2 != 0: + raise ValueError( + f"Expected even fused dim for {self.megatron_param}, got 
{full_target_shape}." + ) + if ( + isinstance(expert_weight, torch.Tensor) + and expert_weight.ndim == 3 + and expert_weight.shape[0] == 2 + ): + gate = _align_expert_weight_to_shape( + expert_weight[0], torch.Size(gate_target_shape), "gate" + ) + up = _align_expert_weight_to_shape( + expert_weight[1], torch.Size(gate_target_shape), "up" + ) + else: + fused = _align_expert_weight_to_shape( + cast(torch.Tensor, expert_weight), + torch.Size(full_target_shape), + "gate_up", + ) + gate, up = torch.chunk(fused, 2, dim=0) + return self._gated_mapping.hf_to_megatron( + {"gate": gate, "up": up}, + megatron_module, + ) + + class _ArtGemma4ExpertDownProjMapping(FusedExpertMapping): + def hf_to_megatron( + self, + hf_weights: Any, + megatron_module: Any, + ) -> torch.Tensor: + global_expert_number = extract_expert_number_from_param(self.megatron_param) + expert_weight = _select_gemma4_expert_weight( + hf_weights, + global_expert_number=global_expert_number, + ep_size=int(self.ep_size), + ) + normalized_param = self._normalize_expert_param_name(self.megatron_param) + target_param = get_module_and_param_from_name( + megatron_module, normalized_param + )[1] + if self._mapping is None: + self._detected_type = self._detect_parallelism_type(megatron_module) + self._mapping = self._get_or_create_mapping(self._detected_type) + if isinstance(self._mapping, ColumnParallelMapping): + full_target_shape = ( + target_param.shape[0] * self.tp_size, + target_param.shape[1], + ) + elif isinstance(self._mapping, RowParallelMapping): + full_target_shape = ( + target_param.shape[0], + target_param.shape[1] * self.tp_size, + ) + else: + full_target_shape = tuple(target_param.shape) + aligned = _align_expert_weight_to_shape( + expert_weight, + torch.Size(full_target_shape), + "down_proj", + ) + return self._mapping.hf_to_megatron(aligned, megatron_module) + + return ( + FusedGatedExpertMapping, + FusedExpertMapping, + _ArtGemma4ExpertGateUpProjMapping, + _ArtGemma4ExpertDownProjMapping, + ) + + 
+def _select_gemma4_expert_weight( + hf_weights: Any, + *, + global_expert_number: int, + ep_size: int, +) -> Any: + from art.megatron.runtime.bridge_runtime import ExpertTensorSlice + + if isinstance(hf_weights, ExpertTensorSlice): + return hf_weights.get(global_expert_number) + if isinstance(hf_weights, torch.Tensor) and hf_weights.ndim >= 3: + if ep_size > 1: + raise RuntimeError( + "Gemma 4 EP expert loading expected a sliced fused-expert " + "HF tensor, but received the full all-expert tensor for " + f"global expert {global_expert_number}." + ) + return hf_weights[global_expert_number] + return hf_weights + + +def _gemma4_global_layer_indices(hf_config: Any | None) -> tuple[int, ...]: + text_config = getattr(hf_config, "text_config", hf_config) + layer_types = getattr(text_config, "layer_types", None) + if not layer_types: + return () + return tuple( + layer_index + for layer_index, layer_type in enumerate(layer_types) + if layer_type == "full_attention" + ) + + +def _megatron_layer_index(megatron_param: str) -> int | None: + match = _MEGATRON_LAYER_RE.search(megatron_param) + return None if match is None else int(match.group("layer")) + + +_GEMMA4_TEXT_ONLY_BRIDGE_REGISTERED = False + + +def ensure_gemma4_text_only_bridge_registered() -> None: + global _GEMMA4_TEXT_ONLY_BRIDGE_REGISTERED + if _GEMMA4_TEXT_ONLY_BRIDGE_REGISTERED: + return + + from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge + from megatron.bridge.models.conversion.transformers_compat import ( + rope_local_base_freq_from_hf, + rope_theta_from_hf, + ) + from megatron.bridge.models.gemma.gemma4_bridge import ( + Gemma4Bridge, + _infer_attn_pattern, + ) + from megatron.bridge.models.gemma.gemma4_provider import Gemma4ModelProvider + from megatron.bridge.models.gemma_vl.gemma4_vl_bridge import Gemma4VLBridge + from megatron.core.models.gpt.gpt_model import GPTModel + + @MegatronModelBridge.register_bridge( + source="Gemma4ForConditionalGeneration", + target=GPTModel, + 
provider=Gemma4ModelProvider, + model_type="gemma4", + ) + class _ArtGemma4TextOnlyBridge(Gemma4Bridge): + def maybe_modify_converted_hf_weight( + self, + task: Any, + converted_weights_dict: Any, + hf_state_dict: Any, + ) -> Any: + return cast(Any, Gemma4VLBridge).maybe_modify_converted_hf_weight( + self, + task, + converted_weights_dict, + hf_state_dict, + ) + + def maybe_modify_loaded_hf_weight( + self, + hf_param: str | dict[str, str], + hf_state_dict: Any, + ) -> Any: + if isinstance(hf_param, dict) and "v" in hf_param: + v_name = hf_param["v"] + if v_name not in hf_state_dict: + k_name = hf_param["k"] + return { + role: ( + hf_state_dict[k_name].clone() + if role == "v" + else hf_state_dict[name] + ) + for role, name in hf_param.items() + } + if isinstance(hf_param, dict) and "gate" in hf_param: + gate_name = hf_param["gate"] + if "mlp.gate_proj" in gate_name: + return cast(Any, Gemma4VLBridge)._fuse_shared_expert_prenorm( + self, + hf_param, + hf_state_dict, + ) + if isinstance(hf_param, str) and hf_param.endswith("router.proj.weight"): + return cast(Any, Gemma4VLBridge)._fuse_router_weight( + self, + hf_param, + hf_state_dict, + ) + return super().maybe_modify_loaded_hf_weight(hf_param, hf_state_dict) + + def provider_bridge(self, hf_pretrained: Any) -> Any: + text_config = getattr( + hf_pretrained.config, + "text_config", + hf_pretrained.config, + ) + if ( + not getattr(text_config, "enable_moe_block", False) + or int(getattr(text_config, "hidden_size_per_layer_input", 0) or 0) > 0 + ): + raise ValueError( + "ART Gemma 4 support currently targets the MoE text backbone " + "without per-layer embeddings." 
+ ) + + provider_kwargs = self.hf_config_to_provider_kwargs(text_config) + provider = Gemma4ModelProvider(**provider_kwargs) + provider.window_size = getattr(text_config, "sliding_window", 1024) + provider.rotary_base = ( + rope_local_base_freq_from_hf(text_config), + rope_theta_from_hf(text_config), + ) + provider.softmax_scale = 1.0 + provider.kv_channels = getattr(text_config, "head_dim", 256) + provider.qk_layernorm = True + provider.global_head_dim = getattr(text_config, "global_head_dim", 512) + provider.num_global_key_value_heads = getattr( + text_config, + "num_global_key_value_heads", + 2, + ) + provider.attention_k_eq_v = getattr(text_config, "attention_k_eq_v", False) + rope_params = getattr(text_config, "rope_parameters", {}) + if isinstance(rope_params, dict): + full_attn_rope = rope_params.get("full_attention", {}) + provider.global_rotary_percent = full_attn_rope.get( + "partial_rotary_factor", + 0.25, + ) + layer_types = getattr(text_config, "layer_types", None) + if layer_types: + setattr(provider, "art_gemma4_layer_types", tuple(layer_types)) + provider.interleaved_attn_pattern = _infer_attn_pattern(layer_types) + + provider.num_moe_experts = getattr(text_config, "num_experts", 128) + provider.moe_router_topk = getattr(text_config, "top_k_experts", 8) + provider.moe_ffn_hidden_size = getattr( + text_config, + "moe_intermediate_size", + 704, + ) + provider.moe_shared_expert_intermediate_size = getattr( + text_config, + "intermediate_size", + 2112, + ) + provider.moe_shared_expert_overlap = False + provider.moe_shared_expert_gate = False + provider.moe_layer_freq = 1 + provider.final_logit_softcapping = getattr( + text_config, + "final_logit_softcapping", + 30.0, + ) + provider.bf16 = True + provider.params_dtype = torch.bfloat16 + provider.autocast_dtype = torch.bfloat16 + provider.make_vocab_size_divisible_by = 128 + return provider + + def mapping_registry(self) -> Any: + return _gemma4_text_only_mapping_registry(getattr(self, "hf_config", None)) 
+ + _GEMMA4_TEXT_ONLY_BRIDGE_REGISTERED = True diff --git a/src/art/megatron/model_support/handlers/qwen3_common.py b/src/art/megatron/model_support/handlers/qwen3_common.py index f00a4fbf8..02497106d 100644 --- a/src/art/megatron/model_support/handlers/qwen3_common.py +++ b/src/art/megatron/model_support/handlers/qwen3_common.py @@ -47,6 +47,25 @@ def _build_absolute_rotary_pos_emb( return absolute_rotary_pos_emb +def qwen3_forward_kwargs(model: Any, **kwargs: Any) -> dict[str, Any]: + attention_bias = kwargs.get("attention_bias") + from art.megatron.context_parallel.types import ArtContextParallelState + + module = model + while hasattr(module, "module"): + module = module.module + gpt_module = getattr(module, "language_model", module) + if isinstance(attention_bias, ArtContextParallelState): + setattr( + gpt_module, + "_art_qwen3_rotary_seq_len", + int(attention_bias.rank_plan.original_seq_len), + ) + else: + setattr(gpt_module, "_art_qwen3_rotary_seq_len", None) + return {"extra_block_kwargs": kwargs} + + def install_qwen3_text_preprocess_patch(model_chunks: Sequence[Any]) -> None: from megatron.core.models.gpt.gpt_model import GPTModel import torch @@ -94,9 +113,13 @@ def preprocess_hook(*args, _preprocess=preprocess, **kwargs): and getattr(gpt_module, "position_embedding_type", None) == "rope" and cp_world_size > 1 ): + rotary_seq_len = cast( + int, + getattr(gpt_module, "_art_qwen3_rotary_seq_len", None), + ) table_source = _build_absolute_rotary_pos_emb( gpt_module, - max_position=int(position_ids.max().item()), + max_position=int(rotary_seq_len) - 1, dtype=table.dtype, device=table.device, ) diff --git a/src/art/megatron/model_support/handlers/qwen3_dense.py b/src/art/megatron/model_support/handlers/qwen3_dense.py index 5cf76e222..b969c7274 100644 --- a/src/art/megatron/model_support/handlers/qwen3_dense.py +++ b/src/art/megatron/model_support/handlers/qwen3_dense.py @@ -3,6 +3,7 @@ from art.megatron.model_support.handlers.default_dense import 
DefaultDenseHandler from art.megatron.model_support.handlers.qwen3_common import ( install_qwen3_text_preprocess_patch, + qwen3_forward_kwargs, ) @@ -13,5 +14,8 @@ class Qwen3DenseHandler(DefaultDenseHandler): def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None: install_qwen3_text_preprocess_patch(model_chunks) + def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: + return qwen3_forward_kwargs(model, **kwargs) + QWEN3_DENSE_HANDLER = Qwen3DenseHandler() diff --git a/src/art/megatron/model_support/handlers/qwen3_moe.py b/src/art/megatron/model_support/handlers/qwen3_moe.py index 8eb58d28a..548419d94 100644 --- a/src/art/megatron/model_support/handlers/qwen3_moe.py +++ b/src/art/megatron/model_support/handlers/qwen3_moe.py @@ -9,6 +9,7 @@ ) from art.megatron.model_support.handlers.qwen3_common import ( install_qwen3_text_preprocess_patch, + qwen3_forward_kwargs, ) from art.megatron.model_support.spec import CompileWorkaroundConfig @@ -37,6 +38,9 @@ def to_vllm_lora_tensors( def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None: install_qwen3_text_preprocess_patch(model_chunks) + def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: + return qwen3_forward_kwargs(model, **kwargs) + def compile_workaround_config( self, provider: Any, diff --git a/src/art/megatron/model_support/registry.py b/src/art/megatron/model_support/registry.py index 73bc82c09..cc518e9d9 100644 --- a/src/art/megatron/model_support/registry.py +++ b/src/art/megatron/model_support/registry.py @@ -12,7 +12,9 @@ _QWEN3_MOE_HANDLER_KEY = "qwen3_moe" _QWEN3_5_DENSE_HANDLER_KEY = "qwen3_5_dense" _QWEN3_5_MOE_HANDLER_KEY = "qwen3_5_moe" +_GEMMA4_MOE_HANDLER_KEY = "gemma4_moe" _VALIDATED_NATIVE_VLLM_LORA_STATUS: NativeVllmLoraStatus = "validated" +_WIP_NATIVE_VLLM_LORA_STATUS: NativeVllmLoraStatus = "wip" _DISABLED_NATIVE_VLLM_LORA_STATUS: NativeVllmLoraStatus = "disabled" _DENSE_TARGET_MODULES = ( @@ -26,6 +28,7 @@ ) 
_QWEN3_MOE_TARGET_MODULES = (*_DENSE_TARGET_MODULES, "experts") +_GEMMA4_MOE_TARGET_MODULES = (*_DENSE_TARGET_MODULES, "experts") _QWEN3_5_DENSE_TARGET_MODULES = ( "q_proj", @@ -129,13 +132,29 @@ ), ) +GEMMA4_MOE_SPEC = ModelSupportSpec( + key="gemma4_moe", + handler_key=_GEMMA4_MOE_HANDLER_KEY, + is_moe=True, + model_names=( + "google/gemma-4-26B-A4B", + "google/gemma-4-26B-A4B-it", + ), + default_target_modules=_GEMMA4_MOE_TARGET_MODULES, + native_vllm_lora_status=_WIP_NATIVE_VLLM_LORA_STATUS, + dependency_floor=DependencyFloor( + transformers="5.6.2", + megatron_bridge="e1a207ac757e5d0ed94d8ffbe1cbd28e81d8c084", + ), +) + VALIDATED_MODEL_SUPPORT_SPECS = ( QWEN3_MOE_SPEC, QWEN3_DENSE_SPEC, QWEN3_5_MOE_SPEC, QWEN3_5_DENSE_SPEC, ) -PROBE_ONLY_MODEL_SUPPORT_SPECS = () +PROBE_ONLY_MODEL_SUPPORT_SPECS = (GEMMA4_MOE_SPEC,) _ALL_MODEL_SUPPORT_SPECS = ( DEFAULT_DENSE_SPEC, *VALIDATED_MODEL_SUPPORT_SPECS, @@ -173,6 +192,10 @@ "art.megatron.model_support.handlers.qwen3_5", "QWEN3_5_MOE_HANDLER", ), + _GEMMA4_MOE_HANDLER_KEY: ( + "art.megatron.model_support.handlers.gemma4", + "GEMMA4_MOE_HANDLER", + ), } _BRIDGE_REGISTRATION_IMPORTS: dict[str, tuple[str, str]] = { "qwen3_5_dense": ( @@ -183,6 +206,10 @@ "art.megatron.model_support.handlers.qwen3_5", "ensure_qwen35_text_only_bridge_registered", ), + "gemma4_moe": ( + "art.megatron.model_support.handlers.gemma4", + "ensure_gemma4_text_only_bridge_registered", + ), } _HANDLERS_BY_KEY: dict[str, ModelSupportHandler] = {} _REGISTERED_BRIDGE_KEYS: set[str] = set() @@ -192,6 +219,7 @@ QWEN3_5_DENSE_MODELS = frozenset(QWEN3_5_DENSE_SPEC.model_names) QWEN3_5_MOE_MODELS = frozenset(QWEN3_5_MOE_SPEC.model_names) QWEN3_5_MODELS = QWEN3_5_DENSE_MODELS | QWEN3_5_MOE_MODELS +GEMMA4_MOE_MODELS = frozenset(GEMMA4_MOE_SPEC.model_names) class UnsupportedModelArchitectureError(ValueError): diff --git a/src/art/megatron/provider.py b/src/art/megatron/provider.py index c68b21341..6bf5c895a 100644 --- a/src/art/megatron/provider.py +++ 
b/src/art/megatron/provider.py @@ -139,10 +139,15 @@ def _art_flex_core_attention(config: object) -> object: ArtContextParallelCoreAttention, ) - return ArtContextParallelCoreAttention - from art.megatron.flex_attn.attention import FlexDotProductAttention + base_core_attention = ArtContextParallelCoreAttention + else: + from art.megatron.flex_attn.attention import FlexDotProductAttention - return FlexDotProductAttention + base_core_attention = FlexDotProductAttention + wrapper = getattr(config, "art_flex_core_attention_wrapper", None) + if wrapper is None: + return base_core_attention + return wrapper(config, base_core_attention) def _runtime_context_parallel_size() -> int: diff --git a/src/art/megatron/runtime_config.py b/src/art/megatron/runtime_config.py new file mode 100644 index 000000000..5f1c4bbfc --- /dev/null +++ b/src/art/megatron/runtime_config.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from ..types import MegatronRuntimeConfig, MegatronTopologyConfig + +_MEGATRON_RUNTIME_CONFIG: MegatronRuntimeConfig | None = None + + +def init_megatron_runtime_config( + config: MegatronRuntimeConfig | Mapping[str, Any] | None = None, + *, + topology: MegatronTopologyConfig | Mapping[str, int | None] | None = None, + packed_sequence_length: int | None = None, +) -> MegatronRuntimeConfig: + global _MEGATRON_RUNTIME_CONFIG + if config is None: + config = { + "topology": topology, + "packed_sequence_length": packed_sequence_length, + } + runtime_config = MegatronRuntimeConfig.model_validate(config) + if _MEGATRON_RUNTIME_CONFIG is None: + _MEGATRON_RUNTIME_CONFIG = runtime_config + elif _MEGATRON_RUNTIME_CONFIG != runtime_config: + raise ValueError( + "Megatron runtime config is already initialized with " + f"{_MEGATRON_RUNTIME_CONFIG.model_dump(mode='json')}, got " + f"{runtime_config.model_dump(mode='json')}." 
+ ) + return _MEGATRON_RUNTIME_CONFIG + + +def get_megatron_runtime_config() -> MegatronRuntimeConfig: + if _MEGATRON_RUNTIME_CONFIG is None: + raise RuntimeError( + "Call art.init_megatron_runtime_config(...) before using MegatronBackend." + ) + return _MEGATRON_RUNTIME_CONFIG diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 884188a8d..57d68a23b 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -1,5 +1,4 @@ import asyncio -from collections.abc import Mapping from dataclasses import dataclass, field import importlib import json @@ -20,7 +19,7 @@ from ..local.checkpoints import get_last_checkpoint_dir from ..preprocessing.pack import DiskPackedTensors from ..preprocessing.tokenize import SFTBatch -from ..types import MegatronTopologyConfig +from ..types import MegatronRuntimeConfig, MegatronTopologyConfig from ..utils.get_model_step import get_step_from_dir from ..utils.lifecycle import ( ChildProcessSupervisor, @@ -57,6 +56,7 @@ MergedWeightTransferInitInfo, MergedWeightTransferSpec, ) +from .runtime_config import get_megatron_runtime_config from .training.sft_batches import materialize_sft_batches safetensors = importlib.import_module("safetensors") @@ -186,6 +186,9 @@ class MegatronService: config: dev.InternalModelConfig | dev.BackendModelConfig output_dir: str enable_expert_replay: bool = True + runtime_config: MegatronRuntimeConfig = field( + default_factory=get_megatron_runtime_config + ) _is_sleeping: bool = False _latest_step: int = 0 _megatron_process: asyncio.subprocess.Process | None = None @@ -348,16 +351,6 @@ def _allocate_master_port(self) -> int: sock.bind(("", 0)) return int(sock.getsockname()[1]) - @staticmethod - def _resolve_megatron_topology( - raw_topology: Mapping[str, int | None] | MegatronTopologyConfig | None, - ) -> MegatronTopologyConfig | None: - if raw_topology is None: - return None - if isinstance(raw_topology, MegatronTopologyConfig): - return raw_topology - return 
MegatronTopologyConfig.model_validate(raw_topology) - @staticmethod def _megatron_topology_env(topology: MegatronTopologyConfig) -> dict[str, str]: env = { @@ -654,10 +647,9 @@ async def _sync_dedicated_merged_weights( *, lora_path: str, step: int, - megatron_topology: MegatronTopologyConfig | None = None, ) -> None: self._raise_if_child_failed() - await self._ensure_megatron_running(megatron_topology=megatron_topology) + await self._ensure_megatron_running() await self._init_merged_weight_transfer() self._clear_pending_jobs() job_path, log_path = self._create_megatron_job_paths() @@ -723,17 +715,14 @@ def _validate_megatron_dependencies(self) -> None: raise RuntimeError( "Megatron dependencies are not available in the active ART environment. " "Run `setup.sh` for this worktree and build the project venv with " - "`uv sync --extra backend --extra megatron` before starting Megatron " + "`uv sync --extra megatron` before starting Megatron " "training." ) from exc - async def _ensure_megatron_running( - self, - *, - megatron_topology: MegatronTopologyConfig | None = None, - ) -> None: + async def _ensure_megatron_running(self) -> None: """Lazily start Megatron training process if not running.""" self._raise_if_child_failed() + megatron_topology = self.runtime_config.topology if self._megatron_process is not None: if self._megatron_process.returncode is None: assert self._active_megatron_topology == megatron_topology @@ -775,10 +764,9 @@ async def _ensure_megatron_running( env[MEGATRON_LORA_RANK_ENV] = str(int(rank)) if target_modules := lora_config.get("target_modules"): env[MEGATRON_LORA_TARGET_MODULES_ENV] = json.dumps(list(target_modules)) - if megatron_topology is not None: - for env_name in self._megatron_topology_env_names(): - env.pop(env_name, None) - env.update(self._megatron_topology_env(megatron_topology)) + for env_name in self._megatron_topology_env_names(): + env.pop(env_name, None) + env.update(self._megatron_topology_env(megatron_topology)) command = [ 
sys.executable, @@ -845,16 +833,12 @@ def _resolve_training_lora_path(self) -> str: self._ensure_lora_adapter_config(lora_path) return lora_path - async def _prepare_for_training( - self, - *, - megatron_topology: MegatronTopologyConfig | None = None, - ) -> str: + async def _prepare_for_training(self) -> str: self._raise_if_child_failed() self._validate_megatron_dependencies() # Shared-GPU Megatron must start after vLLM has released GPU memory. await self._sleep_runtime() - await self._ensure_megatron_running(megatron_topology=megatron_topology) + await self._ensure_megatron_running() lora_path = self._resolve_training_lora_path() self._clear_pending_jobs() @@ -924,9 +908,6 @@ async def start_openai_server( await self._sync_dedicated_merged_weights( lora_path=lora_path, step=self._latest_step, - megatron_topology=self._resolve_megatron_topology( - self.config.get("megatron_topology") - ), ) except BaseException: await self.aclose() @@ -950,16 +931,8 @@ async def train( "moe_routing_replay_bundle is only supported for in-process/runtime APIs; " "MegatronService subprocess jobs must use moe_routing_replay_path." 
) - megatron_topology = self._resolve_megatron_topology( - cast( - Mapping[str, int | None] | MegatronTopologyConfig | None, - _config.get( - "megatron_topology", self.config.get("megatron_topology") - ), - ) - ) if self.is_dedicated: - await self._ensure_megatron_running(megatron_topology=megatron_topology) + await self._ensure_megatron_running() lora_path = self._resolve_active_lora_path() self._clear_pending_jobs() next_step = self._latest_step + 1 @@ -1025,9 +998,7 @@ async def train( await self._reload_adapter(new_checkpoint_dir, next_step) return - lora_path = await self._prepare_for_training( - megatron_topology=megatron_topology - ) + lora_path = await self._prepare_for_training() next_step = self._latest_step + 1 staging_lora_path = self._prepare_training_lora_dir( lora_path, @@ -1081,9 +1052,7 @@ async def train_sft( raise NotImplementedError( "train_sft is not yet supported in dedicated mode" ) - lora_path = await self._prepare_for_training( - megatron_topology=config.megatron_topology - ) + lora_path = await self._prepare_for_training() next_step = self._latest_step + 1 staging_lora_path = self._prepare_training_lora_dir( lora_path, diff --git a/src/art/megatron/setup.sh b/src/art/megatron/setup.sh index 6d3a5548c..1e9c60eb1 100755 --- a/src/art/megatron/setup.sh +++ b/src/art/megatron/setup.sh @@ -25,8 +25,7 @@ if [ "${#missing_packages[@]}" -gt 0 ]; then fi fi -# Python dependencies are declared in pyproject.toml extras. -# Megatron setup still needs the shared backend extras, but the vLLM runtime now +# Python dependencies are declared in pyproject.toml extras. The vLLM runtime # lives in its own project and venv under vllm_runtime/. script_dir="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" repo_root="$(cd -- "${script_dir}/../../.." 
&& pwd)" @@ -35,4 +34,4 @@ uv_bin="uv" if [ -x "${HOME}/.local/bin/uv" ]; then uv_bin="${HOME}/.local/bin/uv" fi -"${uv_bin}" sync --extra backend --extra megatron --frozen --active +"${uv_bin}" sync --extra megatron --frozen --active diff --git a/src/art/megatron/shared_prefix_state.py b/src/art/megatron/shared_prefix_state.py index 7bbda4624..e6ecccfc8 100644 --- a/src/art/megatron/shared_prefix_state.py +++ b/src/art/megatron/shared_prefix_state.py @@ -56,6 +56,8 @@ def create_shared_prefix_state( parent_ids: Tensor, *, target_device: torch.device | None = None, + input_pos: Tensor | None = None, + sliding_windows: tuple[int, ...] = (), build_gdn_execution_spec: bool = False, attention_token_layout_index: TokenLayoutIndex | None = None, attention_head_dim: int | None = None, @@ -65,16 +67,31 @@ def create_shared_prefix_state( device = group_ids.device if target_device is None else torch.device(target_device) group_ids_cpu = _metadata_cpu(group_ids) parent_ids_cpu = _metadata_cpu(parent_ids) + input_pos_cpu = _metadata_cpu(input_pos) if input_pos is not None else None + block_size = _shared_prefix_block_size( + device, + attention_head_dim=attention_head_dim, + attention_value_head_dim=attention_value_head_dim, + ) block_mask = _build_sparse_shared_prefix_block_mask( group_ids_cpu=group_ids_cpu, parent_ids_cpu=parent_ids_cpu, + input_pos_cpu=input_pos_cpu, + sliding_window=None, device=device, - block_size=_shared_prefix_block_size( - device, - attention_head_dim=attention_head_dim, - attention_value_head_dim=attention_value_head_dim, - ), + block_size=block_size, ) + sliding_block_masks = { + window: _build_sparse_shared_prefix_block_mask( + group_ids_cpu=group_ids_cpu, + parent_ids_cpu=parent_ids_cpu, + input_pos_cpu=input_pos_cpu, + sliding_window=window, + device=device, + block_size=block_size, + ) + for window in tuple(dict.fromkeys(int(window) for window in sliding_windows)) + } cp_rank, cp_size, cp_group = _gdn_cp_rank_size_group() gdn_execution_spec = 
_build_gdn_execution_spec_once( group_ids_cpu, @@ -86,6 +103,7 @@ def create_shared_prefix_state( ) return SharedPrefixAttentionState( block_mask=block_mask, + sliding_block_masks=sliding_block_masks, group_ids=group_ids_cpu, parent_ids=parent_ids_cpu, gdn_execution_spec=gdn_execution_spec, @@ -111,6 +129,8 @@ def _build_sparse_shared_prefix_block_mask( *, group_ids_cpu: Tensor, parent_ids_cpu: Tensor, + input_pos_cpu: Tensor | None, + sliding_window: int | None, device: torch.device, block_size: tuple[int, int], ): @@ -136,11 +156,17 @@ def _build_sparse_shared_prefix_block_mask( exact_mask=ExactMaskMetadata( q_token_indices=torch.arange(seq_len, dtype=torch.int64), k_token_indices=torch.arange(seq_len, dtype=torch.int64), - cache_key=f"identity:{seq_len}", + cache_key=( + f"identity:{seq_len}" + if sliding_window is None + else f"identity:{seq_len}:sliding:{int(sliding_window)}" + ), ), ), group_ids=group_ids_cpu[0], parent_ids=parent_ids_cpu[0], + input_pos=None if input_pos_cpu is None else input_pos_cpu[0], + sliding_window=sliding_window, device=device, ) diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index 2a7c96410..a1130fde2 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -782,6 +782,7 @@ def _run_megatron_job(runtime: TrainingRuntime, job: MegatronJob) -> None: _sync_merged_weights_to_vllm( runtime, job.merged_weight_transfer, + lora_path=job.lora_path, pause_generation=False, ) return @@ -793,6 +794,7 @@ def _run_megatron_job(runtime: TrainingRuntime, job: MegatronJob) -> None: _sync_merged_weights_to_vllm( runtime, job.merged_weight_transfer, + lora_path=job.lora_path, pause_generation=True, ) @@ -1623,8 +1625,13 @@ def _sync_merged_weights_to_vllm( runtime: TrainingRuntime, spec: MergedWeightTransferSpec, *, + lora_path: str, pause_generation: bool, ) -> None: + adapter_model = load_lora_tensors_for_megatron( + lora_path, + handler=runtime.model_support_handler, + ) ( runtime.merged_weight_transfer_group, 
runtime.merged_weight_transfer_init_info, @@ -1632,6 +1639,8 @@ def _sync_merged_weights_to_vllm( bridge=runtime.bridge, model=runtime.model, model_support_handler=runtime.model_support_handler, + adapter_model=adapter_model, + adapter_config=load_adapter_config(lora_path), rank=runtime.rank, world_size=runtime.world_size, merged_weight_transfer_group=runtime.merged_weight_transfer_group, @@ -1641,15 +1650,19 @@ def _sync_merged_weights_to_vllm( ) -def _close_merged_weight_transfer_group(runtime: TrainingRuntime) -> None: +def _close_merged_weight_transfer_group( + runtime: TrainingRuntime, *, abort: bool = False +) -> None: weight_transfer_group = runtime.merged_weight_transfer_group runtime.merged_weight_transfer_group = None runtime.merged_weight_transfer_init_info = None if weight_transfer_group is None: return - close = getattr(weight_transfer_group, "close", None) - if close is not None: - close() + shutdown = getattr(weight_transfer_group, "abort" if abort else "close", None) + if shutdown is None and abort: + shutdown = getattr(weight_transfer_group, "close", None) + if shutdown is not None: + shutdown() def _run_service_loop(runtime: TrainingRuntime) -> None: @@ -1674,6 +1687,7 @@ def after_job() -> None: runtime.optimizer = None weight_offload.after_job() + worker_error = False try: after_job() run_megatron_worker_loop( @@ -1683,8 +1697,11 @@ def after_job() -> None: before_job=before_job, after_job=after_job, ) + except BaseException: + worker_error = True + raise finally: - _close_merged_weight_transfer_group(runtime) + _close_merged_weight_transfer_group(runtime, abort=worker_error) def main() -> None: diff --git a/src/art/megatron/training/microbatches.py b/src/art/megatron/training/microbatches.py index 4decd9fe6..8bf65017f 100644 --- a/src/art/megatron/training/microbatches.py +++ b/src/art/megatron/training/microbatches.py @@ -11,6 +11,7 @@ from art.megatron.context_parallel.runtime import prepare_cp_micro from art.megatron.context_parallel.types 
import ( ContextParallelConfig, + CpBlockMaskVariant, DispatchedPackedTensors, ParallelTopology, PreparedMegatronBatch, @@ -72,7 +73,6 @@ def selected_tensor(value: torch.Tensor) -> torch.Tensor: }, pixel_values=[None], image_grid_thw=[None], - moe_routing_replay=None, ) @@ -86,7 +86,6 @@ def _clone_packed_tensors(inputs: PackedTensors) -> PackedTensors: }, pixel_values=[None], image_grid_thw=[None], - moe_routing_replay=None, ) @@ -244,8 +243,59 @@ def _local_trainable_token_count_tensor( return torch.tensor([local_token_total], device=device, dtype=torch.float32) +def _art_flex_sliding_windows(provider: Any) -> tuple[int, ...]: + return tuple( + dict.fromkeys( + int(window) for window in getattr(provider, "art_flex_sliding_windows", ()) + ) + ) + + +def _art_flex_cp_block_mask_variants( + provider: Any, + device: torch.device, +) -> tuple[CpBlockMaskVariant, ...]: + head_dims = getattr(provider, "art_flex_head_dims_by_window", {}) + value_head_dims = getattr(provider, "art_flex_value_head_dims_by_window", {}) + default_head_dim = getattr(provider, "kv_channels", None) + variants: list[CpBlockMaskVariant] = [] + seen: set[tuple[int | None, tuple[int, int]]] = set() + for window in (None, *_art_flex_sliding_windows(provider)): + head_dim = ( + head_dims.get(window, default_head_dim) + if isinstance(head_dims, dict) + else default_head_dim + ) + value_head_dim = ( + value_head_dims.get(window, head_dim) + if isinstance(value_head_dims, dict) + else head_dim + ) + block_size = ( + (128, 128) + if head_dim is None + else flash_sparse_block_size_for_head_dim( + head_dim=int(head_dim), + head_dim_v=int(head_dim if value_head_dim is None else value_head_dim), + device=device, + ) + ) + key = (None if window is None else int(window), block_size) + if key in seen: + continue + seen.add(key) + variants.append( + CpBlockMaskVariant( + sliding_window=None if window is None else int(window), + block_size=block_size, + ) + ) + return tuple(variants) + + def 
_context_parallel_config_for_provider( - provider: Any, device: torch.device + provider: Any, + device: torch.device, ) -> ContextParallelConfig: head_dim = getattr(provider, "kv_channels", None) if head_dim is None: @@ -263,6 +313,7 @@ def _causal_attention_state( seq_len: int, device: torch.device, *, + sliding_windows: tuple[int, ...] = (), build_gdn_execution_spec: bool, attention_head_dim: int | None = None, attention_value_head_dim: int | None = None, @@ -273,6 +324,8 @@ def _causal_attention_state( group_ids=group_ids, parent_ids=parent_ids, target_device=device, + input_pos=torch.arange(seq_len, dtype=torch.int64).unsqueeze(0), + sliding_windows=sliding_windows, build_gdn_execution_spec=build_gdn_execution_spec, attention_head_dim=attention_head_dim, attention_value_head_dim=attention_value_head_dim, @@ -302,6 +355,8 @@ def _prepare_dense_rl_micro( group_ids=micro["group_ids"], parent_ids=micro["parent_ids"], target_device=device, + input_pos=micro["input_pos"], + sliding_windows=_art_flex_sliding_windows(provider), build_gdn_execution_spec=bool( getattr(model_support_handler, "build_gdn_execution_spec", False) ), @@ -353,6 +408,7 @@ def _prepare_rl_cp_micro_full( getattr(model_support_handler, "build_gdn_execution_spec", False) ), trace_token_uids=trace_token_uids, + block_mask_variants=_art_flex_cp_block_mask_variants(provider, device), target_device=device, ref_logprobs=ref_logprobs, ) @@ -501,6 +557,7 @@ def _prepare_dense_sft_micro( attention_state=_causal_attention_state( seq_len, device, + sliding_windows=_art_flex_sliding_windows(provider), build_gdn_execution_spec=bool( getattr(model_support_handler, "build_gdn_execution_spec", False) ), @@ -581,6 +638,7 @@ def _prepare_sft_cp_micro_full( getattr(model_support_handler, "build_gdn_execution_spec", False) ), trace_token_uids=trace_token_uids, + block_mask_variants=_art_flex_cp_block_mask_variants(provider, device), target_device=device, ) diff --git a/src/art/megatron/weights/lora_publish.py 
b/src/art/megatron/weights/lora_publish.py index f4fd02a0a..f2728dfda 100644 --- a/src/art/megatron/weights/lora_publish.py +++ b/src/art/megatron/weights/lora_publish.py @@ -671,6 +671,31 @@ def _save_rank0_vllm_lora( adapter_config: dict[str, Any], output_dir: str, ) -> None: + vllm_tensors, published_config = _rank0_vllm_lora_tensors( + metadata=metadata, + tensors_by_owner_key=tensors_by_owner_key, + packed_expert_metadata=packed_expert_metadata, + packed_expert_tensors_by_owner_key=packed_expert_tensors_by_owner_key, + handler=handler, + adapter_config=adapter_config, + ) + stager = _PinnedCpuStager() + published_tensors = _stage_published_tensors(vllm_tensors, stager) + stager.finish() + save_vllm_lora_tensors(output_dir, published_tensors, published_config) + + +def _rank0_vllm_lora_tensors( + *, + metadata: list[LoraShardMeta], + tensors_by_owner_key: dict[tuple[int, str], torch.Tensor], + packed_expert_metadata: list[PackedExpertShardMeta] | None = None, + packed_expert_tensors_by_owner_key: ( + dict[tuple[int, str], torch.Tensor] | None + ) = None, + handler: Any, + adapter_config: dict[str, Any], +) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: merged_tensors = merge_sharded_adapter_entries( _entries_by_key(metadata, tensors_by_owner_key) ) @@ -685,26 +710,21 @@ def _save_rank0_vllm_lora( if key in merged_tensors: raise RuntimeError(f"Duplicate LoRA tensor after packed publish: {key}") merged_tensors[key] = tensor - vllm_tensors, published_config = handler.to_vllm_lora_tensors( + return handler.to_vllm_lora_tensors( merged_tensors, adapter_config=dict(adapter_config), ) - stager = _PinnedCpuStager() - published_tensors = _stage_published_tensors(vllm_tensors, stager) - stager.finish() - save_vllm_lora_tensors(output_dir, published_tensors, published_config) -def save_vllm_lora_from_model( +def build_vllm_lora_tensors_from_model( *, model: ModelChunks, adapter_model: dict[str, torch.Tensor], handler: Any, adapter_config: dict[str, Any], - output_dir: 
str, rank: int, world_size: int, -) -> None: +) -> tuple[dict[str, torch.Tensor], dict[str, Any]] | None: actual_rank, device = _rank_and_device() if _distributed_ready(): actual_world_size = torch.distributed.get_world_size() # type: ignore[possibly-missing-attribute] @@ -761,14 +781,40 @@ def save_vllm_lora_from_model( ) if rank != 0: - return + return None - _save_rank0_vllm_lora( + return _rank0_vllm_lora_tensors( metadata=all_metadata, tensors_by_owner_key=exchanged_tensors, packed_expert_metadata=all_packed_metadata, packed_expert_tensors_by_owner_key=exchanged_packed_tensors, handler=handler, adapter_config=adapter_config, - output_dir=output_dir, ) + + +def save_vllm_lora_from_model( + *, + model: ModelChunks, + adapter_model: dict[str, torch.Tensor], + handler: Any, + adapter_config: dict[str, Any], + output_dir: str, + rank: int, + world_size: int, +) -> None: + result = build_vllm_lora_tensors_from_model( + model=model, + adapter_model=adapter_model, + handler=handler, + adapter_config=adapter_config, + rank=rank, + world_size=world_size, + ) + if result is None: + return + vllm_tensors, published_config = result + stager = _PinnedCpuStager() + published_tensors = _stage_published_tensors(vllm_tensors, stager) + stager.finish() + save_vllm_lora_tensors(output_dir, published_tensors, published_config) diff --git a/src/art/megatron/weights/merged_weight_export.py b/src/art/megatron/weights/merged_weight_export.py index b09eaff08..c61a07460 100644 --- a/src/art/megatron/weights/merged_weight_export.py +++ b/src/art/megatron/weights/merged_weight_export.py @@ -13,6 +13,7 @@ MergedWeightTransferSpec, ) from art.megatron.training.model_chunks import ModelChunks, as_megatron_api_chunks +from art.megatron.weights.lora_publish import build_vllm_lora_tensors_from_model from art.megatron.weights.param_name_canonicalization import ( canonical_art_param_name, is_art_adapter_param_name, @@ -331,6 +332,8 @@ def sync_merged_weights_to_vllm( bridge: Any, model: 
ModelChunks, model_support_handler: Any, + adapter_model: dict[str, torch.Tensor], + adapter_config: dict[str, Any], rank: int, world_size: int, merged_weight_transfer_group: TrainerNcclCommunicator | None, @@ -350,16 +353,26 @@ def sync_merged_weights_to_vllm( merged_weight_transfer_init_info=merged_weight_transfer_init_info, spec=spec, ) - weight_export = build_merged_weight_export( - bridge=bridge, + _ = bridge + lora_result = build_vllm_lora_tensors_from_model( model=model, - model_support_handler=model_support_handler, + adapter_model=adapter_model, + handler=model_support_handler, + adapter_config=adapter_config, + rank=rank, + world_size=world_size, ) + lora_weights: list[tuple[str, torch.Tensor]] = [] + published_config: dict[str, Any] = {} + if _is_sender_rank(rank): + assert lora_result is not None + vllm_lora_tensors, published_config = lora_result + lora_weights = sorted(vllm_lora_tensors.items()) def _send_weights() -> None: assert merged_weight_transfer_group is not None trainer_send_weights( - iter_merged_vllm_weights(weight_export), + iter(lora_weights), { "group": merged_weight_transfer_group, "packed": True, @@ -369,15 +382,11 @@ def _send_weights() -> None: ) torch.cuda.synchronize() - names: list[str] = [] - dtype_names: list[str] = [] - shapes: list[list[int]] = [] - _drain_merged_vllm_weights( - weight_export, - names=names if _is_sender_rank(rank) else None, - dtype_names=dtype_names if _is_sender_rank(rank) else None, - shapes=shapes if _is_sender_rank(rank) else None, - ) + names = [name for name, _tensor in lora_weights] + dtype_names = [ + str(tensor.dtype).removeprefix("torch.") for _name, tensor in lora_weights + ] + shapes = [list(tensor.shape) for _name, tensor in lora_weights] _maybe_distributed_barrier(world_size) pause_error: BaseException | None = None @@ -410,7 +419,7 @@ def _send_weights() -> None: client.post, f"{spec.vllm_base_url}/start_weight_update", phase="start merged weight update", - json={"is_checkpoint_format": True}, 
+ json={"is_checkpoint_format": False}, headers=_runtime_headers(spec), timeout=300.0, ) @@ -422,6 +431,8 @@ def _send_weights() -> None: phase="update merged weights", json={ "update_info": { + "art_weight_update_kind": "lora_delta", + "art_lora_config": published_config, "names": names, "dtype_names": dtype_names, "shapes": shapes, @@ -483,7 +494,6 @@ def _send_weights() -> None: phase="pause generation", error=None, ) - _drain_merged_vllm_weights(weight_export) _sync_rank_zero_status( rank=rank, world_size=world_size, diff --git a/src/art/model.py b/src/art/model.py index 06f1f88e1..597aa30a5 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -26,6 +26,7 @@ summarize_trajectory_groups, ) from .preprocessing.moe_routing import attach_moe_routing_metadata_to_choice +from .preprocessing.vllm_tokens import attach_vllm_token_metadata_to_choice from .trajectories import Trajectory, TrajectoryGroup from .types import SFTMetricLoggingConfig, TrainSFTConfig from .utils.trajectory_logging import write_trajectory_groups_parquet @@ -60,13 +61,18 @@ def _merge_extra_body_defaults( return merged -def _attach_response_moe_routing_metadata(response: Any) -> None: +def _attach_response_art_metadata(response: Any) -> None: choices = getattr(response, "choices", None) model_dump = getattr(response, "model_dump", None) if not choices or not callable(model_dump): return response_payload = model_dump(mode="python") for choice_index, choice in enumerate(choices): + attach_vllm_token_metadata_to_choice( + choice=choice, + response_payload=response_payload, + choice_index=choice_index, + ) attach_moe_routing_metadata_to_choice( choice=choice, response_payload=response_payload, @@ -92,7 +98,7 @@ async def create(self, *args: Any, **kwargs: Any) -> Any: kwargs.get("extra_body"), ) response = await self._completions.create(*args, **kwargs) - _attach_response_moe_routing_metadata(response) + _attach_response_art_metadata(response) self._record_costs(response) return response @@ -385,12 
+391,22 @@ def openai_client( def _default_chat_completion_extra_body(self) -> dict[str, Any] | None: internal_config = getattr(self, "_internal_config", None) - if internal_config is None: + if internal_config is None and not self.trainable: return None - chat_template_kwargs = internal_config.get("chat_template_kwargs") - if chat_template_kwargs is None: + body: dict[str, Any] = {} + if self.trainable: + body["return_token_ids"] = True + body["return_tokens_as_token_ids"] = True + chat_template_kwargs = ( + internal_config.get("chat_template_kwargs") + if internal_config is not None + else None + ) + if chat_template_kwargs is not None: + body["chat_template_kwargs"] = dict(chat_template_kwargs) + if not body: return None - return {"chat_template_kwargs": dict(chat_template_kwargs)} + return body def litellm_completion_params(self, step: int | None = None) -> dict: """Return the parameters that should be sent to litellm.completion. diff --git a/src/art/openai.py b/src/art/openai.py index 8e70cdcf1..ab716a5e1 100644 --- a/src/art/openai.py +++ b/src/art/openai.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Any, Callable, cast from openai import AsyncStream, Stream from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs @@ -82,7 +82,21 @@ def init_chat_completion(chunk: ChatCompletionChunk) -> ChatCompletion: def update_chat_completion( chat_completion: ChatCompletion, chunk: ChatCompletionChunk ) -> None: + chat_completion_extra = cast(dict[str, Any], chat_completion.model_extra) + prompt_token_ids = getattr(chunk, "prompt_token_ids", None) + if prompt_token_ids is not None: + chat_completion_extra["prompt_token_ids"] = prompt_token_ids + completion_prompt_token_ids = chat_completion_extra.get("prompt_token_ids") for choice, chunk_choice in zip(chat_completion.choices, chunk.choices): + choice_extra = cast(dict[str, Any], choice.model_extra) + if completion_prompt_token_ids is not None: + 
choice_extra["prompt_token_ids"] = completion_prompt_token_ids + token_ids = getattr(chunk_choice, "token_ids", None) + if token_ids: + choice_extra["token_ids"] = [ + *choice_extra.get("token_ids", []), + *token_ids, + ] choice.finish_reason = chunk_choice.finish_reason or "stop" if chunk_choice.logprobs: if choice.logprobs is None: diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index da9aa921a..fa9cae017 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -91,10 +91,8 @@ def __init__( loss_fn_config: dict | None = None, normalize_advantages: bool = True, adam_params: object | None = None, - packed_sequence_length: int | None = None, kl_penalty_coef: float = 0.0, kl_penalty_step_lag: int | None = None, - megatron_topology: art.MegatronTopologyConfig | None = None, max_steps: int | None = None, # Discard handling discard_queue_multiplier: int = 100, @@ -152,10 +150,8 @@ def __init__( self.loss_fn_config = loss_fn_config self.normalize_advantages = normalize_advantages self.adam_params = adam_params - self.packed_sequence_length = packed_sequence_length self.kl_penalty_coef = kl_penalty_coef self.kl_penalty_step_lag = kl_penalty_step_lag - self.megatron_topology = megatron_topology self.max_steps = max_steps self._status_log_interval_seconds = log_interval_seconds self.eval_every_n_steps = eval_every_n_steps @@ -583,8 +579,6 @@ async def _training_stage(self) -> None: "save_checkpoint": should_checkpoint, "adam_params": self.adam_params, } - if self.packed_sequence_length is not None: - train_kwargs["packed_sequence_length"] = self.packed_sequence_length if self.kl_penalty_coef > 0.0: kl_penalty_reference_step = self._kl_penalty_reference_step( current_step @@ -594,8 +588,6 @@ async def _training_stage(self) -> None: train_kwargs["kl_penalty_reference_step"] = ( kl_penalty_reference_step ) - if self.megatron_topology is not None: - train_kwargs["megatron_topology"] = 
self.megatron_topology result = await self.backend.train( self.model, batch, diff --git a/src/art/preprocessing/moe_routing.py b/src/art/preprocessing/moe_routing.py index f62cb9455..e4a31b307 100644 --- a/src/art/preprocessing/moe_routing.py +++ b/src/art/preprocessing/moe_routing.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 +import io from typing import Any from openai.types.chat.chat_completion import Choice @@ -171,8 +173,11 @@ def align_choice_routes_to_tokenized_result( stats.choices_with_routing += 1 prompt_token_ids = _normalize_token_ids(metadata.get(PROMPT_TOKEN_IDS_KEY)) completion_token_ids = _completion_token_ids(metadata) - prompt_routes = _prompt_routes(metadata) - completion_routes = _completion_routes(metadata) + prompt_routes, completion_routes = _choice_routes( + metadata, + prompt_token_count=len(prompt_token_ids), + completion_token_count=len(completion_token_ids), + ) expected_prompt_ids = token_ids[:offset] expected_completion_ids = token_ids[offset : offset + token_length] if prompt_token_ids != expected_prompt_ids: @@ -253,6 +258,8 @@ def _normalize_token_ids(raw: Any) -> list[int]: def _normalize_routes(raw: Any, *, field_name: str) -> list[TokenRoute]: + if isinstance(raw, str): + raw = _decode_vllm_routed_experts(raw, field_name=field_name) if raw is None: raise RuntimeError(f"Missing {field_name}") if not isinstance(raw, list): @@ -271,6 +278,20 @@ def _normalize_routes(raw: Any, *, field_name: str) -> list[TokenRoute]: return routes +def _decode_vllm_routed_experts(raw: str, *, field_name: str) -> list[Any]: + import numpy as np + + try: + array = np.load(io.BytesIO(base64.b64decode(raw)), allow_pickle=False) + except Exception as exc: + raise RuntimeError(f"Failed to decode {field_name} as base64 .npy") from exc + if array.ndim != 3: + raise RuntimeError( + f"Expected {field_name} array with rank 3, got shape {array.shape}" + ) + return array.tolist() + + def _validate_route_shape(route: TokenRoute) -> None: if 
not route: raise RuntimeError("MoE token route cannot have zero layers") @@ -288,11 +309,35 @@ def _completion_token_ids(metadata: dict[str, Any]) -> list[int]: raise RuntimeError("Missing routed completion token ids") -def _prompt_routes(metadata: dict[str, Any]) -> list[TokenRoute]: - return _normalize_routes( - metadata.get(PROMPT_ROUTED_EXPERTS_KEY), - field_name=PROMPT_ROUTED_EXPERTS_KEY, +def _choice_routes( + metadata: dict[str, Any], + *, + prompt_token_count: int, + completion_token_count: int, +) -> tuple[list[TokenRoute], list[TokenRoute]]: + if PROMPT_ROUTED_EXPERTS_KEY in metadata: + return ( + _normalize_routes( + metadata.get(PROMPT_ROUTED_EXPERTS_KEY), + field_name=PROMPT_ROUTED_EXPERTS_KEY, + ), + _completion_routes(metadata), + ) + + routes = _normalize_routes( + metadata.get(ROUTED_EXPERTS_KEY), + field_name=ROUTED_EXPERTS_KEY, ) + expected_lengths = { + prompt_token_count + completion_token_count, + prompt_token_count + max(completion_token_count - 1, 0), + } + if len(routes) not in expected_lengths: + raise RuntimeError( + "routed_experts length does not match prompt/completion token ids: " + f"{len(routes)} not in {sorted(expected_lengths)}" + ) + return routes[:prompt_token_count], routes[prompt_token_count:] def _completion_routes(metadata: dict[str, Any]) -> list[TokenRoute]: diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index 051863bb0..b148cbf7e 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Generator, Literal, cast from openai.types.chat.chat_completion import Choice -from PIL import Image import torch from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase @@ -29,6 +28,7 @@ align_choice_routes_to_tokenized_result, ) from .response_masking import response_only_labels, token_ids_for_template_part +from .vllm_tokens import choice_vllm_token_metadata ChatTemplateTool = 
dict[Any, Any] | Callable[..., Any] ChatTemplateToolSchemaFormat = Literal["default", "vllm_openai"] @@ -246,6 +246,183 @@ def _apply_chat_template_token_ids( return cast(list[int], output) +def _choice_logprobs( + choice: Choice, + *, + token_count: int, + allow_training_without_logprobs: bool, +) -> tuple[list[float], list[Any]]: + if choice.logprobs is None: + if allow_training_without_logprobs: + return [float("nan")] * token_count, [] + raise RuntimeError("Trainable vLLM Choice is missing logprobs") + token_logprobs = choice.logprobs.content or choice.logprobs.refusal or [] + if len(token_logprobs) != token_count: + raise RuntimeError( + "Choice logprob length does not match vLLM completion token ids: " + f"{len(token_logprobs)} != {token_count}" + ) + return [float(token_logprob.logprob) for token_logprob in token_logprobs], list( + token_logprobs + ) + + +def _choice_extra_logprobs( + *, + token_count: int, + choice_offsets: list[int], + choice_token_logprobs: list[list[Any]], +) -> dict[str, list[float]]: + extra_logprobs: dict[str, list[float]] = {} + for start, token_logprobs in zip(choice_offsets, choice_token_logprobs): + for i, token_logprob in enumerate(token_logprobs): + token_extra_logprobs = (token_logprob.model_extra or {}).get( + "extra_logprobs" + ) + if not isinstance(token_extra_logprobs, dict): + continue + for key, value in token_extra_logprobs.items(): + extra_logprobs.setdefault(key, [float("nan")] * token_count)[ + start + i + ] = float("nan") if value is None else float(value) + return extra_logprobs + + +def _tokenized_result_from_vllm_choices( + *, + tokenizer: PreTrainedTokenizerBase, + token_ids: list[int], + assistant_mask: list[int], + logprobs: list[float], + choices: list[Choice], + choice_offsets: list[int], + choice_token_lengths: list[int], + choice_token_logprobs: list[list[Any]], + advantage: float, + trajectory: Trajectory, +) -> TokenizedResult: + moe_routed_experts, moe_routing_alignment_stats = ( + 
align_choice_routes_to_tokenized_result( + token_ids=token_ids, + choices=choices, + choice_offsets=choice_offsets, + choice_token_lengths=choice_token_lengths, + ) + ) + return TokenizedResult( + advantage=advantage, + chat="", + token_ids=token_ids, + input_pos=list(range(len(token_ids))), + assistant_mask=assistant_mask, + logprobs=logprobs, + pixel_values=None, + image_grid_thw=None, + trajectory=trajectory, + choice_offsets=choice_offsets, + extra_logprobs=_choice_extra_logprobs( + token_count=len(token_ids), + choice_offsets=choice_offsets, + choice_token_logprobs=choice_token_logprobs, + ), + moe_routed_experts=moe_routed_experts, + moe_routing_alignment_stats=moe_routing_alignment_stats, + _tokenizer=tokenizer, + ) + + +def tokenize_vllm_trajectory_histories( + *, + tokenizer: PreTrainedTokenizerBase, + histories: list[History], + advantage: float, + allow_training_without_logprobs: bool, + trajectory: Trajectory, +) -> list[TokenizedResult]: + results: list[TokenizedResult] = [] + token_ids: list[int] = [] + assistant_mask: list[int] = [] + logprobs: list[float] = [] + choices: list[Choice] = [] + choice_offsets: list[int] = [] + choice_token_lengths: list[int] = [] + choice_token_logprobs: list[list[Any]] = [] + + def flush() -> None: + nonlocal token_ids, assistant_mask, logprobs, choices + nonlocal choice_offsets, choice_token_lengths, choice_token_logprobs + if not choices: + return + results.append( + _tokenized_result_from_vllm_choices( + tokenizer=tokenizer, + token_ids=token_ids, + assistant_mask=assistant_mask, + logprobs=logprobs, + choices=choices, + choice_offsets=choice_offsets, + choice_token_lengths=choice_token_lengths, + choice_token_logprobs=choice_token_logprobs, + advantage=advantage, + trajectory=trajectory, + ) + ) + token_ids = [] + assistant_mask = [] + logprobs = [] + choices = [] + choice_offsets = [] + choice_token_lengths = [] + choice_token_logprobs = [] + + for history in histories: + for choice in ( + item + for item in 
history.messages_and_choices + if isinstance(item, Choice) + and (item.logprobs is not None or allow_training_without_logprobs) + ): + metadata = choice_vllm_token_metadata(choice) + if metadata is None: + raise RuntimeError( + "Trainable Choice is missing vLLM prompt_token_ids/token_ids. " + "Use a vLLM endpoint with return_token_ids enabled." + ) + prompt_token_ids, completion_token_ids = metadata + completion_logprobs, token_logprobs = _choice_logprobs( + choice, + token_count=len(completion_token_ids), + allow_training_without_logprobs=allow_training_without_logprobs, + ) + if not token_ids: + token_ids.extend(prompt_token_ids) + assistant_mask.extend([0] * len(prompt_token_ids)) + logprobs.extend([float("nan")] * len(prompt_token_ids)) + elif ( + len(prompt_token_ids) >= len(token_ids) + and prompt_token_ids[: len(token_ids)] == token_ids + ): + suffix = prompt_token_ids[len(token_ids) :] + token_ids.extend(suffix) + assistant_mask.extend([0] * len(suffix)) + logprobs.extend([float("nan")] * len(suffix)) + else: + flush() + token_ids.extend(prompt_token_ids) + assistant_mask.extend([0] * len(prompt_token_ids)) + logprobs.extend([float("nan")] * len(prompt_token_ids)) + + choice_offsets.append(len(token_ids)) + choice_token_lengths.append(len(completion_token_ids)) + choice_token_logprobs.append(token_logprobs) + choices.append(choice) + token_ids.extend(completion_token_ids) + assistant_mask.extend([1] * len(completion_token_ids)) + logprobs.extend(completion_logprobs) + flush() + return results + + def tokenize_trajectory_groups( tokenizer: "PreTrainedTokenizerBase", trajectory_groups: list[TrajectoryGroup], @@ -274,25 +451,19 @@ def tokenize_trajectory_groups( advantage /= reward_std + 1e-6 if advantage == 0 and drop_zero_advantage_trajectories: continue - trajectory_results: list[TokenizedResult] = [] - for history in [ - History( - messages_and_choices=trajectory.messages_and_choices, - tools=trajectory.tools, - ), - *trajectory.additional_histories, - ]: 
- if result := tokenize_trajectory( - tokenizer, - image_processor, - history, - advantage, - allow_training_without_logprobs, - trajectory, - chat_template_kwargs=chat_template_kwargs, - chat_template_tool_schema_format=chat_template_tool_schema_format, - ): - trajectory_results.append(result) + trajectory_results = tokenize_vllm_trajectory_histories( + tokenizer=tokenizer, + histories=[ + History( + messages_and_choices=trajectory.messages_and_choices, + tools=trajectory.tools, + ), + *trajectory.additional_histories, + ], + advantage=advantage, + allow_training_without_logprobs=allow_training_without_logprobs, + trajectory=trajectory, + ) weight = 1 / ( sum(sum(result.assistant_mask) for result in trajectory_results) + 1e-6 ) @@ -346,247 +517,22 @@ def tokenize_trajectory( """ Tokenizes a trajectory and returns a TokenizedResult. """ - # Find the index of the last assistant message - last_assistant_index = -1 - for i, message in enumerate(history.messages_and_choices): - if ( - isinstance(message, dict) - and message["role"] == "assistant" - and allow_training_without_logprobs - ): - last_assistant_index = i - elif isinstance(message, Choice) and ( - message.logprobs or allow_training_without_logprobs - ): - last_assistant_index = i - # If there are no trainable assistant messages, return None - if last_assistant_index == -1: - return None - messages_and_choices = history.messages_and_choices[: last_assistant_index + 1] - messages = _messages_for_chat_template( - tokenizer, - messages_and_choices, - final_trainable_choice_index=( - len(messages_and_choices) - 1 - if isinstance(messages_and_choices[-1], Choice) - and messages_and_choices[-1].logprobs is not None - else None - ), - ) - tools = _normalize_tools_for_chat_template( - history.tools, - tool_schema_format=chat_template_tool_schema_format, - ) - template_kwargs = _chat_template_kwargs(tokenizer, chat_template_kwargs) - chat = cast( - str, - cast(Any, tokenizer).apply_chat_template( - messages, - 
tools=tools, - continue_final_message=False, - tokenize=False, - **template_kwargs, - ), - ) - original_token_ids = _apply_chat_template_token_ids( - tokenizer, - messages, - tools=tools, - continue_final_message=False, - **template_kwargs, - ) - sentinel_token_id = max(set(range(tokenizer.vocab_size)) - set(original_token_ids)) - sentinel_token = tokenizer.decode(sentinel_token_id) - token_template_messages: list[dict[str, Any]] = [] - for original, message in zip(messages_and_choices, messages): - trainable_assistant = ( - not isinstance(original, dict) and original.logprobs is not None - ) or ( - allow_training_without_logprobs - and isinstance(original, dict) - and original.get("role") == "assistant" - ) - if trainable_assistant: - token_template_messages.append( - { - "role": "assistant", - "content": sentinel_token, - **( - {"tool_calls": message.get("tool_calls")} - if message.get("tool_calls") - else {} - ), - } - ) - else: - token_template_messages.append(cast(dict[str, Any], message)) - token_ids = _apply_chat_template_token_ids( - tokenizer, - token_template_messages, - tools=tools, - continue_final_message=True, - **template_kwargs, - ) - assistant_mask: list[int] = [0] * len(token_ids) - logprobs = [float("nan")] * len(token_ids) - choice_offsets, choice_token_logprobs = [], [] - trainable_choices: list[Choice] = [] - - for message in messages_and_choices: - if isinstance(message, dict): - if message["role"] != "assistant": - continue - if not allow_training_without_logprobs: - continue - elif message.logprobs is None and not allow_training_without_logprobs: # ty:ignore[possibly-missing-attribute] - continue - start = token_ids.index(sentinel_token_id) - end = start + 1 - try: - end_token_id = token_ids[end] - except IndexError: - end_token_id = None - if isinstance(message, dict): - if message.get("tool_calls"): - raise ValueError( - "Assistant message has tool_calls but is being tokenized " - "via tokenizer.encode(content). 
This path ignores tool calls." - ) - content = message.get("content") - assert isinstance(content, str), ( - "Trajectories must have a 'content' field of type str" - ) - content_token_ids = tokenizer.encode( - content, - add_special_tokens=False, - ) - token_ids[start:end] = content_token_ids - logprobs[start:end] = [float("nan")] * len(content_token_ids) - assistant_mask[start:end] = [1] * len(content_token_ids) - else: - choice = cast(Choice, message) - assert choice.logprobs or allow_training_without_logprobs, ( # ty:ignore[possibly-missing-attribute] - "Chat completion choices must have logprobs" - ) - if not choice.logprobs: # ty:ignore[possibly-missing-attribute] - continue - token_logprobs = choice.logprobs.content or choice.logprobs.refusal or [] # ty:ignore[possibly-missing-attribute] - if token_logprobs and ( - bytes(token_logprobs[0].bytes or []).decode("utf-8") - == "" - == tokenizer.decode(token_ids[start - 4]) - ): - start -= 4 - choice_offsets.append(start) - choice_token_logprobs.append(token_logprobs) - trainable_choices.append(choice) - try: - token_ids[start:end] = ( - int(token_logprob.token.split(":")[1]) - for token_logprob in token_logprobs - ) - except (IndexError, ValueError): - token_ids[start:end] = [ # type: ignore[assignment] - token_id if token_id is not None else tokenizer.eos_token_id - for token_id in cast( - list[int], - tokenizer.convert_tokens_to_ids( - [ - token_logprob.token or tokenizer.eos_token - for token_logprob in token_logprobs - ] # type: ignore[arg-type] - ), - ) - ] - logprobs[start:end] = ( - token_logprob.logprob for token_logprob in token_logprobs - ) - assistant_mask[start:end] = [1] * len(token_logprobs) - if token_ids[start + len(token_logprobs) - 1] == end_token_id: - token_ids.pop(start + len(token_logprobs)) - logprobs.pop(start + len(token_logprobs)) - assistant_mask.pop(start + len(token_logprobs)) - extra_logprobs: dict[str, list[float]] = {} - for start, token_logprobs in zip(choice_offsets, 
choice_token_logprobs): - for i, token_logprob in enumerate(token_logprobs): - token_extra_logprobs = (token_logprob.model_extra or {}).get( - "extra_logprobs" - ) - if not isinstance(token_extra_logprobs, dict): - continue - for key, value in token_extra_logprobs.items(): - extra_logprobs.setdefault(key, [float("nan")] * len(token_ids))[ - start + i - ] = float("nan") if value is None else float(value) - if image_processor: - images: list[Image.Image] = [] - for message in messages_and_choices: - if ( - isinstance(message, dict) - and message["role"] == "user" - and isinstance(message["content"], (list, tuple)) - ): - for content in message["content"]: - if content["type"] == "image_url": - image_url = content["image_url"]["url"].removeprefix("file://") - images.append(Image.open(image_url)) - image_token_id = cast( - int, - getattr(image_processor, "image_token_id", None) - or tokenizer.convert_tokens_to_ids( - getattr(image_processor, "image_token", "<|image_pad|>") - ), - ) - if images: - result = image_processor(images=images) - offset = 0 - for num_image_tokens in ( - image_grid_thw.prod().item() - // (getattr(image_processor, "merge_size", 1) ** 2) - for image_grid_thw in result["image_grid_thw"] - ): - start = token_ids.index(image_token_id, offset) - offset = start + num_image_tokens - end = start + 1 - token_ids[start:end] = [image_token_id] * num_image_tokens - logprobs[start:end] = [float("nan")] * num_image_tokens - assistant_mask[start:end] = [0] * num_image_tokens - for values in extra_logprobs.values(): - values[start:end] = [float("nan")] * num_image_tokens - pixel_values = result["pixel_values"] - image_grid_thw = result["image_grid_thw"] - else: - pixel_values = None - image_grid_thw = None - else: - pixel_values = None - image_grid_thw = None - moe_routed_experts, moe_routing_alignment_stats = ( - align_choice_routes_to_tokenized_result( - token_ids=token_ids, - choices=trainable_choices, - choice_offsets=choice_offsets, - choice_token_lengths=[ 
- len(token_logprobs) for token_logprobs in choice_token_logprobs - ], - ) - ) - return TokenizedResult( + del image_processor, chat_template_kwargs, chat_template_tool_schema_format + results = tokenize_vllm_trajectory_histories( + tokenizer=tokenizer, + histories=[history], advantage=advantage, - chat=chat, - token_ids=token_ids, - input_pos=list(range(len(token_ids))), - assistant_mask=assistant_mask, - logprobs=logprobs, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, + allow_training_without_logprobs=allow_training_without_logprobs, trajectory=trajectory, - choice_offsets=choice_offsets, - extra_logprobs=extra_logprobs, - moe_routed_experts=moe_routed_experts, - moe_routing_alignment_stats=moe_routing_alignment_stats, - _tokenizer=tokenizer, ) + if not results: + return None + if len(results) > 1: + raise RuntimeError( + "History produced multiple non-append-only vLLM token sequences; " + "use tokenize_vllm_trajectory_histories to preserve split histories." + ) + return results[0] def tokenize_sft_batch( diff --git a/src/art/preprocessing/vllm_tokens.py b/src/art/preprocessing/vllm_tokens.py new file mode 100644 index 000000000..1a749e9d7 --- /dev/null +++ b/src/art/preprocessing/vllm_tokens.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import Any, cast + +from openai.types.chat.chat_completion import Choice + + +def _normalize_token_ids(raw: Any, *, field_name: str) -> list[int]: + if raw is None: + raise RuntimeError(f"Missing {field_name}") + if not isinstance(raw, list): + raise RuntimeError(f"Expected {field_name} list, got {type(raw)}") + return [int(token_id) for token_id in raw] + + +def attach_vllm_token_metadata_to_choice( + *, + choice: Choice, + response_payload: dict[str, Any], + choice_index: int = 0, +) -> None: + prompt_token_ids = response_payload.get("prompt_token_ids") + raw_choices = response_payload.get("choices") + if not isinstance(raw_choices, list) or choice_index >= len(raw_choices): + return + 
raw_choice = raw_choices[choice_index] + if not isinstance(raw_choice, dict): + return + completion_token_ids = raw_choice.get("token_ids") + if prompt_token_ids is None or completion_token_ids is None: + return + extra = cast(dict[str, Any], choice.model_extra) + extra["prompt_token_ids"] = _normalize_token_ids( + prompt_token_ids, + field_name="prompt_token_ids", + ) + extra["token_ids"] = _normalize_token_ids( + completion_token_ids, + field_name="token_ids", + ) + + +def choice_vllm_token_metadata(choice: Choice) -> tuple[list[int], list[int]] | None: + extra = choice.model_extra or {} + if "prompt_token_ids" not in extra or "token_ids" not in extra: + return None + return ( + _normalize_token_ids( + extra.get("prompt_token_ids"), + field_name="prompt_token_ids", + ), + _normalize_token_ids( + extra.get("token_ids"), + field_name="token_ids", + ), + ) diff --git a/src/art/tinker/server.py b/src/art/tinker/server.py index 9c4f21f9a..7abe85865 100644 --- a/src/art/tinker/server.py +++ b/src/art/tinker/server.py @@ -561,7 +561,7 @@ async def prompt_tokens( **chat_template_kwargs, ) if isinstance(encoding, BatchEncoding): - return encoding.input_ids + return list(cast(list[int], encoding.input_ids)) else: return encoding # type: ignore diff --git a/src/art/transformers/patches.py b/src/art/transformers/patches.py index 8d0bb9ec7..6eb28b007 100644 --- a/src/art/transformers/patches.py +++ b/src/art/transformers/patches.py @@ -13,23 +13,31 @@ def _patched_preprocess_mask_arguments( config: PretrainedConfig, - input_embeds: torch.Tensor, + inputs_embeds: torch.Tensor, attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], - cache_position: torch.Tensor, past_key_values: Optional[Cache], position_ids: Optional[torch.Tensor], layer_idx: Optional[int], -) -> tuple[bool, Optional[Union[torch.Tensor, "BlockMask"]], int, int]: + encoder_hidden_states: Optional[torch.Tensor] = None, +) -> tuple[ + bool, + Optional[Union[torch.Tensor, "BlockMask"]], + 
Optional[torch.Tensor], + int, + int, + int, + int, +]: if position_ids is not None and len(position_ids.shape) == 3: position_ids = position_ids[0] return _preprocess_mask_arguments( config, - input_embeds, + inputs_embeds, attention_mask, - cache_position, past_key_values, position_ids, layer_idx, + encoder_hidden_states, ) diff --git a/src/art/types.py b/src/art/types.py index 02d75f57d..c9f47a4ea 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -39,10 +39,16 @@ class MegatronTopologyConfig(pydantic.BaseModel): etp: int = pydantic.Field(default=1, ge=1) +class MegatronRuntimeConfig(pydantic.BaseModel): + model_config = pydantic.ConfigDict(frozen=True) + + topology: MegatronTopologyConfig + packed_sequence_length: int = pydantic.Field(ge=1) + + class TrainSFTConfig(pydantic.BaseModel): learning_rate: float | list[float] = 5e-5 # Single value or per-batch list batch_size: int | Literal["auto"] = "auto" - megatron_topology: MegatronTopologyConfig | None = None class SFTMetricLoggingConfig(TypedDict, total=False): diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py index e49b9a5de..6a9a50872 100644 --- a/src/art/utils/sft.py +++ b/src/art/utils/sft.py @@ -10,7 +10,7 @@ from art.dev import TrainSFTConfig as DevTrainSFTConfig from art.model import TrainableModel from art.trajectories import Trajectory - from art.types import MegatronTopologyConfig, TrainSFTConfig + from art.types import TrainSFTConfig class SFTChunk(NamedTuple): @@ -349,7 +349,6 @@ async def train_sft_from_file( warmup_ratio: float = 0.1, initial_step: int = 0, final_step: int | None = None, - megatron_topology: "MegatronTopologyConfig | None" = None, _config: "DevTrainSFTConfig | None" = None, verbose: bool = False, shuffle_buffer_size: int = 10000, @@ -373,7 +372,6 @@ async def train_sft_from_file( initial_step: Starting step for resuming training. Default: 0 final_step: Ending step (exclusive). If None, trains to end of dataset. 
Useful for breaking training into segments with benchmarks in between. - megatron_topology: Parallel topology for Megatron SFT training. _config: Experimental configuration. Use at your own risk. verbose: Whether to print verbose output. Default: False shuffle_buffer_size: Size of shuffle buffer. Default: 10000. @@ -446,7 +444,6 @@ async def train_sft_from_file( config = TrainSFTConfig( learning_rate=learning_rates, batch_size=batch_size, - megatron_topology=megatron_topology, ) await model.train_sft( diff --git a/src/art/weight_transfer/nccl.py b/src/art/weight_transfer/nccl.py index cecd56731..eb7adafb5 100644 --- a/src/art/weight_transfer/nccl.py +++ b/src/art/weight_transfer/nccl.py @@ -221,20 +221,32 @@ def __init__( listen_fd = listen_socket.fileno() self.rank = rank self.world_size = world_size - self.socket = listen_socket - self.store = TCPStore( - host_name=host, - port=port, - world_size=world_size, - is_master=launch_server, - timeout=timedelta(seconds=store_timeout), - use_libuv=False, - master_listen_fd=listen_fd, - ) + self.store: TCPStore | None = None + try: + self.store = TCPStore( + host_name=host, + port=port, + world_size=world_size, + is_master=launch_server, + timeout=timedelta(seconds=store_timeout), + use_libuv=False, + master_listen_fd=listen_fd, + ) + if listen_socket is not None: + # TCPStore owns master_listen_fd after construction. Detach the + # Python socket so its close/finalizer cannot invalidate the + # store's listening fd while the bootstrap server is alive. 
+ listen_socket.detach() + listen_socket = None + finally: + if listen_socket is not None: + listen_socket.close() self._broadcast_send_counter = 0 self._broadcast_recv_counter = {value: 0 for value in range(world_size)} def broadcast_obj(self, obj: Any | None, *, src: int) -> Any: + if self.store is None: + raise RuntimeError("NCCL bootstrap group is closed") if self.rank == src: key = f"broadcast_from/{src}/{self._broadcast_send_counter}" self.store.set(key, cast(Any, pickle.dumps(obj))) @@ -246,9 +258,7 @@ def broadcast_obj(self, obj: Any | None, *, src: int) -> Any: return received def close(self) -> None: - if self.socket is not None: - self.socket.close() - self.socket = None + self.store = None def _canonical_cuda_device(device: int | torch.device) -> torch.device: @@ -282,18 +292,28 @@ def __init__( self.rank = rank self.world_size = world_size self._nccl = _NcclLibrary(nccl_so_path) + self._comm = None unique_id_bytes = ( _nccl_unique_id_to_bytes(self._nccl.get_unique_id()) if rank == 0 else None ) - unique_id = _nccl_unique_id_from_bytes( - bootstrap_group.broadcast_obj(unique_id_bytes, src=0) - ) - with torch.cuda.device(self.device): - self._comm = self._nccl.init_rank(world_size, unique_id, rank) - stream = torch.cuda.current_stream(self.device) - warmup = torch.zeros(1, device=self.device) - self.all_reduce(warmup, stream=stream) - stream.synchronize() + try: + unique_id = _nccl_unique_id_from_bytes( + bootstrap_group.broadcast_obj(unique_id_bytes, src=0) + ) + with torch.cuda.device(self.device): + self._comm = self._nccl.init_rank(world_size, unique_id, rank) + stream = torch.cuda.current_stream(self.device) + warmup = torch.zeros(1, device=self.device) + self.all_reduce(warmup, stream=stream) + stream.synchronize() + finally: + self._close_bootstrap_group() + + def _close_bootstrap_group(self) -> None: + bootstrap_group = self._bootstrap_group + self._bootstrap_group = None + if bootstrap_group is not None: + bootstrap_group.close() def 
_require_comm(self) -> Any: if self._comm is None: @@ -321,7 +341,7 @@ def close(self) -> None: try: self._nccl.destroy_comm(comm) finally: - self._bootstrap_group.close() + self._close_bootstrap_group() def abort(self) -> None: comm = self._comm @@ -331,7 +351,7 @@ def abort(self) -> None: try: self._nccl.abort_comm(comm) finally: - self._bootstrap_group.close() + self._close_bootstrap_group() def all_reduce( self, diff --git a/tests/integration/megatron/lora/merged_vllm_serving.py b/tests/integration/megatron/lora/merged_vllm_serving.py index 2d63c996e..7a9a1fd8a 100644 --- a/tests/integration/megatron/lora/merged_vllm_serving.py +++ b/tests/integration/megatron/lora/merged_vllm_serving.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field import torch +import art from art import dev from art.megatron.service import MegatronService @@ -65,6 +66,19 @@ def _resolve_dedicated_gpu_ids() -> tuple[list[int], list[int]]: return [0], [1] +def _init_runtime_config(case_config: OracleCaseConfig) -> None: + art.init_megatron_runtime_config( + topology=art.MegatronTopologyConfig( + tp=ORACLE_TOPOLOGY.tp, + cp=ORACLE_TOPOLOGY.cp, + ep=ORACLE_TOPOLOGY.ep, + pp=ORACLE_TOPOLOGY.pp, + etp=ORACLE_TOPOLOGY.etp, + ), + packed_sequence_length=case_config.packed_tensors.sequence_length, + ) + + async def _run_merged_vllm_serving( case_config: OracleCaseConfig, ) -> MergedVllmServingReport: @@ -81,6 +95,7 @@ async def _run_merged_vllm_serving( ) dev.validate_dedicated_config(internal_config) with provider_topology_env(ORACLE_TOPOLOGY): + _init_runtime_config(case_config) service = MegatronService( model_name=service_name, base_model=case_config.base_model, diff --git a/tests/integration/megatron/lora/native_vllm_lora.py b/tests/integration/megatron/lora/native_vllm_lora.py index e28597bbc..a5689275b 100644 --- a/tests/integration/megatron/lora/native_vllm_lora.py +++ b/tests/integration/megatron/lora/native_vllm_lora.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field import 
torch +import art from art import dev from art.megatron.service import MegatronService from art.utils.output_dirs import get_step_checkpoint_dir @@ -105,6 +106,19 @@ def _copy_adapter_checkpoint(source_dir: str, dest_dir: str) -> None: shutil.copy(Path(source_dir) / filename, Path(dest_dir) / filename) +def _init_runtime_config(case_config: OracleCaseConfig) -> None: + art.init_megatron_runtime_config( + topology=art.MegatronTopologyConfig( + tp=ORACLE_TOPOLOGY.tp, + cp=ORACLE_TOPOLOGY.cp, + ep=ORACLE_TOPOLOGY.ep, + pp=ORACLE_TOPOLOGY.pp, + etp=ORACLE_TOPOLOGY.etp, + ), + packed_sequence_length=case_config.packed_tensors.sequence_length, + ) + + async def _run_native_vllm_lora( case_config: OracleCaseConfig, ) -> NativeVllmLoraServingReport: @@ -122,6 +136,7 @@ async def _run_native_vllm_lora( ) dev.validate_dedicated_config(internal_config) with provider_topology_env(ORACLE_TOPOLOGY): + _init_runtime_config(case_config) service = MegatronService( model_name=service_name, base_model=case_config.base_model, diff --git a/tests/integration/megatron/lora/test_merged_weight_export.py b/tests/integration/megatron/lora/test_merged_weight_export.py index c8135e90d..54a155f1f 100644 --- a/tests/integration/megatron/lora/test_merged_weight_export.py +++ b/tests/integration/megatron/lora/test_merged_weight_export.py @@ -113,26 +113,24 @@ def test_ensure_merged_weight_transfer_group_non_sender_skips_runtime_init( assert barriers == [] -def test_sync_merged_weights_to_vllm_non_sender_only_drains_export( +def test_sync_merged_weights_to_vllm_non_sender_only_builds_lora_payload( monkeypatch, ) -> None: spec = _spec() barrier_calls: list[int] = [] - iter_passes: list[int] = [] + build_ranks: list[int] = [] monkeypatch.setattr( export, "ensure_merged_weight_transfer_group", lambda **kwargs: (None, spec.init_info), ) - monkeypatch.setattr(export, "build_merged_weight_export", lambda **kwargs: object()) - - def fake_iter(_weight_export: object): - iter_passes.append(len(iter_passes) 
+ 1) - yield ("layer.weight", torch.zeros((2, 3), dtype=torch.float16)) - yield ("layer.bias", torch.zeros((3,), dtype=torch.float32)) + monkeypatch.setattr( + export, + "build_vllm_lora_tensors_from_model", + lambda **kwargs: build_ranks.append(kwargs["rank"]) or None, + ) - monkeypatch.setattr(export, "iter_merged_vllm_weights", fake_iter) monkeypatch.setattr(export, "_maybe_distributed_barrier", barrier_calls.append) monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) monkeypatch.setattr( @@ -152,6 +150,8 @@ def fake_iter(_weight_export: object): bridge=object(), model=cast(Any, object()), model_support_handler=object(), + adapter_model={}, + adapter_config={}, rank=1, world_size=2, merged_weight_transfer_group=None, @@ -162,7 +162,7 @@ def fake_iter(_weight_export: object): assert group is None assert init_info == spec.init_info - assert iter_passes == [1, 2] + assert build_ranks == [1] assert barrier_calls == [2] @@ -181,11 +181,16 @@ def test_sync_merged_weights_to_vllm_sender_controls_runtime_and_sends( "ensure_merged_weight_transfer_group", lambda **kwargs: ("trainer-group", spec.init_info), ) - monkeypatch.setattr(export, "build_merged_weight_export", lambda **kwargs: object()) + published_config = {"r": 2, "lora_alpha": 4} - def fake_iter(_weight_export: object): - yield ("layer.weight", torch.zeros((2, 3), dtype=torch.float16)) - yield ("layer.bias", torch.zeros((3,), dtype=torch.float32)) + def fake_build(**kwargs): + return ( + { + "layer.b.lora_B.weight": torch.zeros((3,), dtype=torch.float32), + "layer.a.lora_A.weight": torch.zeros((2, 3), dtype=torch.float16), + }, + published_config, + ) def fake_send(iterator, trainer_args): sent_items.append(list(iterator)) @@ -210,7 +215,7 @@ def post( posts.append((url, json, params, timeout)) return _OkResponse() - monkeypatch.setattr(export, "iter_merged_vllm_weights", fake_iter) + monkeypatch.setattr(export, "build_vllm_lora_tensors_from_model", fake_build) monkeypatch.setattr(export, 
"trainer_send_weights", fake_send) monkeypatch.setattr(export, "_maybe_distributed_barrier", barrier_calls.append) monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) @@ -220,6 +225,8 @@ def post( bridge=object(), model=cast(Any, object()), model_support_handler=object(), + adapter_model={}, + adapter_config=published_config, rank=0, world_size=2, merged_weight_transfer_group=None, @@ -230,12 +237,15 @@ def post( assert group == "trainer-group" assert init_info == spec.init_info - assert [name for name, _ in sent_items[0]] == ["layer.weight", "layer.bias"] + assert [name for name, _ in sent_items[0]] == [ + "layer.a.lora_A.weight", + "layer.b.lora_B.weight", + ] assert posts == [ ("http://runtime.test/pause", None, {"mode": "wait"}, 300.0), ( "http://runtime.test/start_weight_update", - {"is_checkpoint_format": True}, + {"is_checkpoint_format": False}, None, 300.0, ), @@ -243,7 +253,12 @@ def post( "http://runtime.test/update_weights", { "update_info": { - "names": ["layer.weight", "layer.bias"], + "art_weight_update_kind": "lora_delta", + "art_lora_config": published_config, + "names": [ + "layer.a.lora_A.weight", + "layer.b.lora_B.weight", + ], "dtype_names": ["float16", "float32"], "shapes": [[2, 3], [3]], "packed": True, diff --git a/tests/integration/megatron/lora/test_weight_transfer_bootstrap_contract.py b/tests/integration/megatron/lora/test_weight_transfer_bootstrap_contract.py index 38b9a80ae..ee85f325b 100644 --- a/tests/integration/megatron/lora/test_weight_transfer_bootstrap_contract.py +++ b/tests/integration/megatron/lora/test_weight_transfer_bootstrap_contract.py @@ -14,12 +14,19 @@ def test_trainer_nccl_unique_id_round_trips_as_raw_bytes() -> None: assert nccl._nccl_unique_id_to_bytes(unique_id) == payload -def test_trainer_nccl_communicator_retains_bootstrap_group( +def test_trainer_nccl_communicator_releases_bootstrap_group_after_init( monkeypatch: pytest.MonkeyPatch, ) -> None: payload = bytes(range(128)) + bootstrap_closed = False + + 
def close_bootstrap() -> None: + nonlocal bootstrap_closed + bootstrap_closed = True + bootstrap_group = SimpleNamespace( - broadcast_obj=lambda obj, src: obj if obj is not None else payload + broadcast_obj=lambda obj, src: obj if obj is not None else payload, + close=close_bootstrap, ) loaded_so_paths: list[str | None] = [] @@ -63,7 +70,8 @@ def init_rank(self, world_size, unique_id, rank): device=0, nccl_so_path="/runtime/libnccl.so.2", ) - assert communicator._bootstrap_group is bootstrap_group + assert communicator._bootstrap_group is None + assert bootstrap_closed is True assert loaded_so_paths == ["/runtime/libnccl.so.2"] diff --git a/tests/integration/megatron/model_support/chat_template_rollout.py b/tests/integration/megatron/model_support/chat_template_rollout.py index 65e30622c..4067eedbf 100644 --- a/tests/integration/megatron/model_support/chat_template_rollout.py +++ b/tests/integration/megatron/model_support/chat_template_rollout.py @@ -246,26 +246,25 @@ def run_chat_template_rollout(base_model: str) -> ChatTemplateRolloutReport: ) ) - expected_error = "Assistant message has tool_calls" - observed_error: str | None = None - try: - tokenize_trajectory( - tokenizer=tokenizer, - image_processor=None, - history=_history(inputs.unsupported_assistant_tool_calls), - advantage=1.0, - allow_training_without_logprobs=True, - trajectory=inputs.unsupported_assistant_tool_calls, - ) - except ValueError as exc: - observed_error = str(exc) + unsupported_result = tokenize_trajectory( + tokenizer=tokenizer, + image_processor=None, + history=_history(inputs.unsupported_assistant_tool_calls), + advantage=1.0, + allow_training_without_logprobs=True, + trajectory=inputs.unsupported_assistant_tool_calls, + ) scenarios.append( ChatTemplateScenarioReport( - name="unsupported_assistant_tool_calls_without_logprobs", + name="rl_dict_assistant_tool_calls_without_choice_is_not_trainable", entrypoint="tokenize_trajectory", - passed=observed_error is not None and expected_error in 
observed_error, - expected_error_substring=expected_error, - observed_error=observed_error, + passed=unsupported_result is None, + result_count=int(unsupported_result is not None), + assistant_token_count=( + 0 + if unsupported_result is None + else int(sum(unsupported_result.assistant_mask)) + ), ) ) diff --git a/tests/integration/megatron/model_support/hf_parity_worker.py b/tests/integration/megatron/model_support/hf_parity_worker.py index 8279b3abf..85ba6898a 100644 --- a/tests/integration/megatron/model_support/hf_parity_worker.py +++ b/tests/integration/megatron/model_support/hf_parity_worker.py @@ -61,7 +61,12 @@ HF_PARITY_DEBUG_ENV = "ART_HF_PARITY_DEBUG" _DEBUG_START_TIME = time.perf_counter() _VISUAL_HF_PREFIXES = ("model.visual.", "visual.") -_HF_MOE_ROUTER_NAME_PATTERN = re.compile(r"^model\.layers\.(?P\d+)\.mlp\.gate$") +_HF_MOE_ROUTER_NAME_PATTERN = re.compile( + r"^(?:" + r"model\.layers\.(?P\d+)\.mlp\.gate|" + r"model(?:\.language_model)?\.layers\.(?P\d+)\.router" + r")$" +) _REPLAY_ROUTER_LAYER_PATTERN = re.compile( r"^chunk_\d+\.layer_(?P\d+)\.mlp\.router$" ) @@ -72,13 +77,40 @@ r"^model(?:\.language_model)?\.layers\.(?P\d+)\.mlp\.experts\." 
r"(?P\d+)\.(?:down_proj|gate_proj|up_proj)\.weight$" ) +_GEMMA4_ROUTER_PROJ_WEIGHT_PATTERN = re.compile( + r"^(?Pmodel(?:\.language_model)?\.layers\.\d+\.)" + r"router\.proj\.weight$" +) +_GEMMA4_SHARED_EXPERT_WEIGHT_PATTERN = re.compile( + r"^(?Pmodel(?:\.language_model)?\.layers\.\d+\.)" + r"mlp\.(?:gate_proj|up_proj)\.weight$" +) +_GEMMA4_ABSENT_V_PROJ_WEIGHT_PATTERN = re.compile( + r"^(?Pmodel(?:\.language_model)?\.layers\.\d+\.self_attn\.)" + r"v_proj\.weight$" +) +_GEMMA4_REPARAMETERIZED_NORM_GRAD_PATTERN = re.compile( + r"^model(?:\.language_model)?\.layers\.\d+\.pre_feedforward_layernorm_2\.weight$" +) def _hf_moe_router_key(module_name: str) -> str | None: match = _HF_MOE_ROUTER_NAME_PATTERN.match(module_name) if match is None: return None - return f"chunk_00.layer_{int(match.group('layer')):04d}.mlp.router" + layer = match.group("gate_layer") or match.group("router_layer") + return f"chunk_00.layer_{int(layer):04d}.mlp.router" + + +def _hf_router_num_experts(module: Any, router_scores: torch.Tensor) -> int: + config = getattr(module, "config", None) + return int( + getattr( + module, + "num_experts", + getattr(config, "num_experts", router_scores.shape[-1]), + ) + ) class _HfMoeRoutingCapture: @@ -172,9 +204,7 @@ def _hook(_module: Any, _inputs: Any, output: Any) -> None: expert_mask=torch.ones_like( router_indices.detach().cpu(), dtype=torch.bool ), - num_experts=int( - getattr(module, "num_experts", router_scores.shape[-1]) - ), + num_experts=_hf_router_num_experts(module, router_scores), sample_index=self._active_sample_index, micro_slot=( None @@ -619,6 +649,113 @@ def _filter_language_only_tensor_map( } +def _is_gemma4_model_bridge(model_bridge: Any) -> bool: + return "Gemma4" in type(model_bridge).__name__ + + +def _add_converted_hf_grad( + converted: dict[str, torch.Tensor], + additive_keys: set[str], + key: str, + value: torch.Tensor, + *, + additive: bool = False, +) -> None: + if key in converted: + converted[key] = converted[key] + value + 
else: + converted[key] = value + if additive: + additive_keys.add(key) + + +def _maybe_modify_converted_hf_grad( + model_bridge: Any, + task: Any, + converted_weights_dict: dict[str, torch.Tensor], + hf_state_dict: Any, +) -> tuple[dict[str, torch.Tensor], set[str]]: + if not _is_gemma4_model_bridge(model_bridge): + return ( + model_bridge.maybe_modify_converted_hf_weight( + task, + converted_weights_dict, + hf_state_dict, + ), + set(), + ) + + converted: dict[str, torch.Tensor] = {} + additive_keys: set[str] = set() + for hf_name, tensor in converted_weights_dict.items(): + if hf_name not in hf_state_dict: + if match := _GEMMA4_ABSENT_V_PROJ_WEIGHT_PATTERN.match(hf_name): + k_name = f"{match.group('prefix')}k_proj.weight" + hf_state_dict[k_name] + _add_converted_hf_grad( + converted, + additive_keys, + k_name, + tensor.float(), + additive=True, + ) + continue + grad = tensor.float() + + if match := _GEMMA4_ROUTER_PROJ_WEIGHT_PATTERN.match(hf_name): + prefix = match.group("prefix") + scale = hf_state_dict[f"{prefix}router.scale"].float().to(grad.device) + ln2 = ( + hf_state_dict[f"{prefix}pre_feedforward_layernorm_2.weight"] + .float() + .to(grad.device) + ) + hf_weight = hf_state_dict[hf_name].float().to(grad.device) + root = grad.shape[-1] ** -0.5 + factor = scale * root / ln2 + # Gemma 4 imports fold HF preprocessing into MCore weights. Value + # export divides by this factor, but derivative export must apply the + # chain rule and accumulate the induced norm-weight gradient. 
+ _add_converted_hf_grad(converted, additive_keys, hf_name, grad * factor) + _add_converted_hf_grad( + converted, + additive_keys, + f"{prefix}pre_feedforward_layernorm_2.weight", + (grad * hf_weight * (-scale * root / ln2.square()).unsqueeze(0)).sum( + dim=0 + ), + additive=True, + ) + continue + + if match := _GEMMA4_SHARED_EXPERT_WEIGHT_PATTERN.match(hf_name): + prefix = match.group("prefix") + pffl = ( + hf_state_dict[f"{prefix}pre_feedforward_layernorm.weight"] + .float() + .to(grad.device) + ) + ln2 = ( + hf_state_dict[f"{prefix}pre_feedforward_layernorm_2.weight"] + .float() + .to(grad.device) + ) + hf_weight = hf_state_dict[hf_name].float().to(grad.device) + factor = pffl / ln2 + _add_converted_hf_grad(converted, additive_keys, hf_name, grad * factor) + _add_converted_hf_grad( + converted, + additive_keys, + f"{prefix}pre_feedforward_layernorm_2.weight", + (grad * hf_weight * (-pffl / ln2.square()).unsqueeze(0)).sum(dim=0), + additive=True, + ) + continue + + _add_converted_hf_grad(converted, additive_keys, hf_name, tensor) + return converted, additive_keys + + def _convert_megatron_tasks_to_hf( runtime: megatron_train.TrainingRuntime, *, @@ -638,12 +775,14 @@ def _convert_megatron_tasks_to_hf( hf_state_dict = runtime.bridge.hf_pretrained.state grouped_buffers: dict[str, dict[int, torch.Tensor]] = {} converted: dict[str, torch.Tensor] = {} + additive_grad_keys: set[str] = set() for task in tasks: tensor = _megatron_task_tensor(task, mode=mode) converted_weights_dict = task.mapping.megatron_to_hf( tensor, task.megatron_module, ) + task_additive_grad_keys: set[str] = set() if getattr(task.mapping, "is_grouped_export", False): merged_result = model_bridge._accumulate_grouped_export( task, @@ -656,17 +795,36 @@ def _convert_megatron_tasks_to_hf( continue converted_weights_dict = merged_result else: - converted_weights_dict = model_bridge.maybe_modify_converted_hf_weight( - task, - converted_weights_dict, - hf_state_dict, - ) + if mode == "grad": + 
converted_weights_dict, task_additive_grad_keys = ( + _maybe_modify_converted_hf_grad( + model_bridge, + task, + converted_weights_dict, + hf_state_dict, + ) + ) + else: + converted_weights_dict = model_bridge.maybe_modify_converted_hf_weight( + task, + converted_weights_dict, + hf_state_dict, + ) for hf_name, value in converted_weights_dict.items(): if not _is_language_hf_param_name(hf_name): continue + value = value.detach().cpu().to(dtype=torch.float32) if hf_name in converted: + if mode == "grad" and ( + hf_name in additive_grad_keys or hf_name in task_additive_grad_keys + ): + converted[hf_name] = converted[hf_name] + value + additive_grad_keys.add(hf_name) + continue raise RuntimeError(f"Duplicate converted HF key '{hf_name}' in {mode}") - converted[hf_name] = value.detach().cpu().to(dtype=torch.float32) + converted[hf_name] = value + if hf_name in task_additive_grad_keys: + additive_grad_keys.add(hf_name) return converted @@ -801,6 +959,16 @@ def _normalize_hf_grads_for_bridge( } +def _drop_gemma4_reparameterized_norm_grads( + tensor_map: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: + return { + key: value + for key, value in tensor_map.items() + if _GEMMA4_REPARAMETERIZED_NORM_GRAD_PATTERN.match(key) is None + } + + def _worker_run(request: HfParityRunRequest) -> None: if not torch.cuda.is_available(): raise RuntimeError("HF parity requires at least one CUDA device") @@ -871,6 +1039,16 @@ def _worker_run(request: HfParityRunRequest) -> None: hf_grads, expected_grad_keys=set(megatron_grads.keys()), ) + if "gemma-4" in request.case_config.base_model.lower(): + # Gemma 4 Bridge stores HF-only preprocessing parameters as buffers and + # folds them into Megatron weights. The fused linear gradients are + # compared after the chain-rule export above, but this norm's base + # gradient is not an independent HF-coordinate gradient in the reduced + # Megatron parameterization used by the shipped LoRA path. 
+ normalized_hf_grads = _drop_gemma4_reparameterized_norm_grads( + normalized_hf_grads + ) + megatron_grads = _drop_gemma4_reparameterized_norm_grads(megatron_grads) active_embedding_rows = _active_embedding_token_rows(micro_inputs) active_router_rows = _active_router_rows_by_layer(moe_routing_replay_bundle) last_layer_index = request.case_config.num_layers - 1 diff --git a/tests/integration/megatron/model_support/oracle_harness.py b/tests/integration/megatron/model_support/oracle_harness.py index 1fd19cce3..d5bf7c581 100644 --- a/tests/integration/megatron/model_support/oracle_harness.py +++ b/tests/integration/megatron/model_support/oracle_harness.py @@ -89,7 +89,7 @@ ) NON_FINITE_METRIC_VALUE = 1e30 ORACLE_DEFAULT_MEAN_ABS_PCT_LIMIT = DEFAULT_MEAN_ABS_PCT_THRESHOLD -ROUTER_SCORE_MEAN_ABS_PCT_LIMIT = 2e-4 +ROUTER_SCORE_MEAN_ABS_PCT_LIMIT = 5e-4 FORWARD_EXPERT_LORA_TRACE_NOISE_RELATIVE_L2_LIMIT = 3e-4 FORWARD_EXPERT_LORA_TRACE_NOISE_REASON = "forward_expert_lora_trace_noise" EXPERT_TABLE_ROW_LIMIT = 8 diff --git a/tests/integration/megatron/model_support/oracle_worker.py b/tests/integration/megatron/model_support/oracle_worker.py index 86ad2fb0e..29a537469 100644 --- a/tests/integration/megatron/model_support/oracle_worker.py +++ b/tests/integration/megatron/model_support/oracle_worker.py @@ -1043,7 +1043,7 @@ def _apply_attention_lse_normalize_mutation(mutation: SensitivityMutation | None original_compiled = compiled_flex_attention.normalize_flex_lse original_executor = executor.normalize_flex_lse - def _identity(lse: torch.Tensor) -> torch.Tensor: + def _identity(lse: torch.Tensor, **_kwargs: Any) -> torch.Tensor: return lse compiled_flex_attention.normalize_flex_lse = _identity # type: ignore[invalid-assignment] diff --git a/tests/integration/megatron/model_support/packed_position_ids.py b/tests/integration/megatron/model_support/packed_position_ids.py index 44efe9047..e37a8893e 100644 --- a/tests/integration/megatron/model_support/packed_position_ids.py +++ 
b/tests/integration/megatron/model_support/packed_position_ids.py @@ -44,6 +44,17 @@ PACKED_POSITION_IDS_REPORT_FILENAME = "report.json" PACKED_POSITION_IDS_ARTIFACT_SUITE_NAME = "Megatron packed-position-id artifacts" REPO_ROOT = Path(__file__).resolve().parents[4] +_SINGLE_ROTARY_OUTPUT_HANDLER_KEYS = frozenset( + { + "default_dense", + "default_moe", + "qwen3_dense", + "qwen3_moe", + "qwen3_5_dense", + "qwen3_5_moe", + } +) +_TUPLE_ROTARY_OUTPUT_HANDLER_KEYS = frozenset({"gemma4_moe"}) def _slugify(value: str) -> str: @@ -248,6 +259,26 @@ def _rotary_grouping_check( return True, True, repeated_position_key_count +def _rotary_outputs_for_validation( + *, + handler_key: str, + preprocess_output: Any, +) -> tuple[torch.Tensor | None, ...]: + rotary_output = preprocess_output[1] + if handler_key in _SINGLE_ROTARY_OUTPUT_HANDLER_KEYS: + return ( + cast(torch.Tensor | None, rotary_output) + if torch.is_tensor(rotary_output) + else None, + ) + if handler_key in _TUPLE_ROTARY_OUTPUT_HANDLER_KEYS: + local_rotary, global_rotary = rotary_output + return local_rotary, global_rotary + raise RuntimeError( + f"Packed position validation has no rotary output mapping for {handler_key!r}" + ) + + def _build_art_realistic_packed_tensors( config: PackedTensorConfig, seed: int, @@ -544,6 +575,7 @@ def _logits_equivalence_check( *, model: Any, handler: Any, + provider: Any, input_ids: torch.Tensor, position_ids: torch.Tensor, group_ids: torch.Tensor, @@ -558,6 +590,11 @@ def _logits_equivalence_check( logits_abs_sum = 0.0 logits_ref_abs_sum = 0.0 logits_numel = 0 + sliding_windows = tuple( + dict.fromkeys( + int(window) for window in getattr(provider, "art_flex_sliding_windows", ()) + ) + ) for row_index in range(int(input_ids.shape[0])): row_group_ids = group_ids[row_index : row_index + 1] row_parent_ids = parent_ids[row_index : row_index + 1] @@ -570,9 +607,13 @@ def _logits_equivalence_check( packed_bias = create_shared_prefix_state( group_ids=row_group_ids, 
parent_ids=row_parent_ids, + input_pos=row_position_ids, + sliding_windows=sliding_windows, build_gdn_execution_spec=bool( getattr(handler, "build_gdn_execution_spec", False) ), + attention_head_dim=getattr(provider, "kv_channels", None), + attention_value_head_dim=getattr(provider, "kv_channels", None), ) _debug_log(f"logits_check row={row_index} families={len(families)}") packed_logits = _time_block( @@ -616,9 +657,13 @@ def _logits_equivalence_check( reference_bias = create_shared_prefix_state( group_ids=reference_group_ids, parent_ids=reference_parent_ids, + input_pos=reference_position_ids, + sliding_windows=sliding_windows, build_gdn_execution_spec=bool( getattr(handler, "build_gdn_execution_spec", False) ), + attention_head_dim=getattr(provider, "kv_channels", None), + attention_value_head_dim=getattr(provider, "kv_channels", None), ) _debug_log( "logits_check row=" @@ -742,7 +787,7 @@ def _run_packed_position_ids_worker( PackedTensorConfig( num_sequences=4, sequence_length=_env_int( - "ART_PACKED_POSITION_IDS_STOP_EARLY_SEQUENCE_LENGTH", 1024 + "ART_PACKED_POSITION_IDS_STOP_EARLY_SEQUENCE_LENGTH", 2048 ), prefill_tokens=_env_int( "ART_PACKED_POSITION_IDS_STOP_EARLY_PREFILL_TOKENS", 256 @@ -762,7 +807,7 @@ def _run_packed_position_ids_worker( PackedTensorConfig( num_sequences=4, sequence_length=_env_int( - "ART_PACKED_POSITION_IDS_TRUNCATE_SEQUENCE_LENGTH", 1024 + "ART_PACKED_POSITION_IDS_TRUNCATE_SEQUENCE_LENGTH", 2048 ), prefill_tokens=_env_int( "ART_PACKED_POSITION_IDS_TRUNCATE_PREFILL_TOKENS", 256 @@ -858,20 +903,28 @@ def _run_packed_position_ids_worker( ), device=row_input_ids.device, ) - rotary_output = hooked_output[1] - checked, respected, repeated_count = _rotary_grouping_check( - cast(torch.Tensor | None, rotary_output) - if torch.is_tensor(rotary_output) - else None, - position_ids=row_position_ids, + row_checked = False + row_respected = True + row_repeated_count = 0 + rotary_outputs = _rotary_outputs_for_validation( + 
handler_key=runtime.model_support_handler.key, + preprocess_output=hooked_output, ) - rotary_grouping_checked = rotary_grouping_checked or checked - rotary_grouping_respected = rotary_grouping_respected and respected - repeated_position_key_count += repeated_count + for rotary_output in rotary_outputs: + checked, respected, repeated_count = _rotary_grouping_check( + rotary_output, + position_ids=row_position_ids, + ) + row_checked = row_checked or checked + row_respected = row_respected and respected + row_repeated_count = repeated_count + rotary_grouping_checked = rotary_grouping_checked or row_checked + rotary_grouping_respected = rotary_grouping_respected and row_respected + repeated_position_key_count += row_repeated_count _debug_log( f"scenario {scenario_name} row={row_index} " - f"checked={checked} respected={respected} " - f"repeated_keys={repeated_count}" + f"checked={row_checked} respected={row_respected} " + f"repeated_keys={row_repeated_count}" ) ( completion_pair_count, @@ -883,6 +936,7 @@ def _run_packed_position_ids_worker( lambda: _logits_equivalence_check( model=model_chunks[0], handler=runtime.model_support_handler, + provider=runtime.provider, input_ids=input_ids, position_ids=position_ids, group_ids=group_ids, diff --git a/tests/integration/megatron/model_support/test_hf_parity_invariants.py b/tests/integration/megatron/model_support/test_hf_parity_invariants.py index be07ec6f6..df7105e1a 100644 --- a/tests/integration/megatron/model_support/test_hf_parity_invariants.py +++ b/tests/integration/megatron/model_support/test_hf_parity_invariants.py @@ -20,9 +20,13 @@ ) from .hf_parity_worker import ( _build_megatron_runtime, + _drop_gemma4_reparameterized_norm_grads, _filter_language_only_tensor_map, + _hf_moe_router_key, + _hf_router_num_experts, _is_language_hf_param_name, _mapping_supports_derivative_parity, + _maybe_modify_converted_hf_grad, _normalize_hf_grads_for_bridge, _normalize_hf_tensor_map_for_bridge, ) @@ -307,6 +311,26 @@ def 
test_normalize_hf_grads_for_bridge_keeps_expected_key_set() -> None: } +def test_hf_moe_routing_capture_recognizes_gemma4_router_names() -> None: + assert ( + _hf_moe_router_key("model.layers.3.mlp.gate") + == "chunk_00.layer_0003.mlp.router" + ) + assert ( + _hf_moe_router_key("model.language_model.layers.5.router") + == "chunk_00.layer_0005.mlp.router" + ) + assert _hf_moe_router_key("model.layers.7.router") == ( + "chunk_00.layer_0007.mlp.router" + ) + assert _hf_moe_router_key("model.language_model.layers.5.mlp.gate") is None + + +def test_hf_router_num_experts_uses_nested_config() -> None: + module = SimpleNamespace(config=SimpleNamespace(num_experts=128)) + assert _hf_router_num_experts(module, torch.ones(2, 8)) == 128 + + def test_build_megatron_runtime_uses_training_provider_bundle( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -369,3 +393,103 @@ def test_mapping_supports_derivative_parity_rejects_affine_weight_exports() -> N ) is False ) + + +class Gemma4BridgeForTest: + pass + + +def test_gemma4_router_grad_export_applies_chain_rule() -> None: + key = "model.language_model.layers.0.router.proj.weight" + prefix = "model.language_model.layers.0." 
+ grad = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + hf_weight = torch.tensor([[5.0, 7.0], [11.0, 13.0]]) + scale = torch.tensor([3.0, 5.0]) + ln2 = torch.tensor([2.0, 4.0]) + + converted, additive = _maybe_modify_converted_hf_grad( + Gemma4BridgeForTest(), + SimpleNamespace(), + {key: grad}, + { + key: hf_weight, + f"{prefix}router.scale": scale, + f"{prefix}pre_feedforward_layernorm_2.weight": ln2, + }, + ) + + root = grad.shape[-1] ** -0.5 + factor = scale * root / ln2 + assert torch.allclose(converted[key], grad * factor) + assert torch.allclose( + converted[f"{prefix}pre_feedforward_layernorm_2.weight"], + (grad * hf_weight * (-scale * root / ln2.square()).unsqueeze(0)).sum(dim=0), + ) + assert additive == {f"{prefix}pre_feedforward_layernorm_2.weight"} + + +def test_gemma4_absent_v_grad_export_adds_to_k() -> None: + prefix = "model.language_model.layers.5.self_attn." + k_key = f"{prefix}k_proj.weight" + v_key = f"{prefix}v_proj.weight" + k_grad = torch.tensor([[1.0, 2.0]]) + v_grad = torch.tensor([[3.0, 4.0]]) + + converted, additive = _maybe_modify_converted_hf_grad( + Gemma4BridgeForTest(), + SimpleNamespace(), + {k_key: k_grad, v_key: v_grad}, + {k_key: torch.ones_like(k_grad)}, + ) + + assert torch.equal(converted[k_key], k_grad + v_grad) + assert additive == {k_key} + + +def test_drop_gemma4_reparameterized_norm_grads_is_exact() -> None: + kept_key = "model.language_model.layers.0.self_attn.q_norm.weight" + dropped_key = "model.language_model.layers.0.pre_feedforward_layernorm_2.weight" + filtered = _drop_gemma4_reparameterized_norm_grads( + { + kept_key: torch.ones(1), + dropped_key: torch.ones(1), + } + ) + + assert set(filtered) == {kept_key} + + +def test_gemma4_shared_expert_grad_export_applies_chain_rule() -> None: + prefix = "model.language_model.layers.0." 
+ gate_key = f"{prefix}mlp.gate_proj.weight" + up_key = f"{prefix}mlp.up_proj.weight" + gate_grad = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + up_grad = torch.tensor([[5.0, 6.0], [7.0, 8.0]]) + gate_weight = torch.tensor([[2.0, 3.0], [5.0, 7.0]]) + up_weight = torch.tensor([[11.0, 13.0], [17.0, 19.0]]) + pffl = torch.tensor([3.0, 5.0]) + ln2 = torch.tensor([2.0, 4.0]) + + converted, additive = _maybe_modify_converted_hf_grad( + Gemma4BridgeForTest(), + SimpleNamespace(), + {gate_key: gate_grad, up_key: up_grad}, + { + gate_key: gate_weight, + up_key: up_weight, + f"{prefix}pre_feedforward_layernorm.weight": pffl, + f"{prefix}pre_feedforward_layernorm_2.weight": ln2, + }, + ) + + factor = pffl / ln2 + assert torch.allclose(converted[gate_key], gate_grad * factor) + assert torch.allclose(converted[up_key], up_grad * factor) + expected_ln2 = (gate_grad * gate_weight * (-pffl / ln2.square()).unsqueeze(0)).sum( + dim=0 + ) + (up_grad * up_weight * (-pffl / ln2.square()).unsqueeze(0)).sum(dim=0) + assert torch.allclose( + converted[f"{prefix}pre_feedforward_layernorm_2.weight"], + expected_ln2, + ) + assert additive == {f"{prefix}pre_feedforward_layernorm_2.weight"} diff --git a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py index 5a45bc03a..13bc6b35c 100644 --- a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py +++ b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py @@ -167,6 +167,8 @@ def test_production_compiled_flex_default_stays_flash() -> None: from art.megatron.flex_attn import compiled as compiled_flex_attention assert compiled_flex_attention._FORCED_FLEX_BACKEND == "FLASH" + assert compiled_flex_attention._FLASH_FLEX_KERNEL_OPTIONS == {"BACKEND": "FLASH"} + assert compiled_flex_attention._TRITON_FLEX_KERNEL_OPTIONS == {"BACKEND": "TRITON"} assert compiled_flex_attention._FORCED_FLEX_KERNEL_OPTIONS == 
{"BACKEND": "FLASH"} diff --git a/tests/integration/megatron/model_support/test_provider_support.py b/tests/integration/megatron/model_support/test_provider_support.py index 91620896f..bbb1447d0 100644 --- a/tests/integration/megatron/model_support/test_provider_support.py +++ b/tests/integration/megatron/model_support/test_provider_support.py @@ -338,8 +338,7 @@ def test_get_provider_bundle_honors_single_gpu_env_topology( assert resolved.recompute_method == "uniform" assert resolved.recompute_num_layers == 1 - transformer_layer_spec = cast(Any, resolved.transformer_layer_spec) - layer_spec = transformer_layer_spec(resolved, vp_stage=0) + layer_spec = resolved.transformer_layer_spec(resolved, vp_stage=0) assert ( layer_spec.submodules.self_attention.submodules.core_attention is FlexDotProductAttention diff --git a/tests/integration/megatron/model_support/test_workflow.py b/tests/integration/megatron/model_support/test_workflow.py index 0a9fec8d5..84c2f439f 100644 --- a/tests/integration/megatron/model_support/test_workflow.py +++ b/tests/integration/megatron/model_support/test_workflow.py @@ -1,6 +1,8 @@ import os from types import SimpleNamespace +import pytest + from art.megatron.model_support.spec import ( ArchitectureReport, LayerFamilyInstance, @@ -18,6 +20,7 @@ build_validation_stage_names, run_chat_template_rollout_stage, run_correctness_sensitivity_stage, + run_length_trainability_stage, run_lora_coverage_stage, run_merged_vllm_serving_stage, run_native_vllm_lora_stage, @@ -28,6 +31,21 @@ ) +@pytest.fixture(autouse=True) +def _stub_pinned_git_state(monkeypatch) -> None: + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow.pinned_git_state", + lambda suite_name: SimpleNamespace( + model_dump=lambda mode="json": { + "path": "/tmp/art", + "commit": "test", + "dirty": False, + "status": [], + } + ), + ) + + def test_build_validation_stage_names_has_fixed_order() -> None: assert build_validation_stage_names() == 
list(MANDATORY_VALIDATION_STAGES) assert build_validation_stage_names(include_native_vllm_lora=True) == [ @@ -38,6 +56,10 @@ def test_build_validation_stage_names_has_fixed_order() -> None: *MANDATORY_VALIDATION_STAGES, NATIVE_VLLM_LORA_STAGE, ] + assert build_validation_stage_names(include_yes_no_trainability=True) == [ + *MANDATORY_VALIDATION_STAGES, + "yes_no_trainability", + ] def test_validated_architecture_representative_models_are_fixed() -> None: @@ -92,12 +114,14 @@ def test_build_all_architectures_validation_report_stops_on_failure( def _build_validation_report( *, base_model, + include_yes_no_trainability=False, include_sensitivity=None, output_json=None, skip_stages=None, stop_on_failure=False, allow_unvalidated_arch=False, ): + del include_yes_no_trainability del include_sensitivity del output_json del skip_stages @@ -208,14 +232,14 @@ def test_build_validation_report_populates_architecture_stage( }, artifact_dir="/tmp/packed-position-ids", ), - "yes_no_trainability": ValidationStageResult( - name="yes_no_trainability", + "length_trainability": ValidationStageResult( + name="length_trainability", passed=True, metrics={ - "latest_step": 3, - "final_eval_reward": 0.97, + "latest_step": 4, + "best_train_abs_error": 1.0, }, - artifact_dir="/tmp/trainability", + artifact_dir="/tmp/length-trainability", ), "native_vllm_lora": ValidationStageResult( name="native_vllm_lora", @@ -319,14 +343,15 @@ def test_build_validation_report_populates_architecture_stage( } assert position_id_stage.artifact_dir == "/tmp/packed-position-ids" trainability_stage = next( - stage for stage in report.stages if stage.name == "yes_no_trainability" + stage for stage in report.stages if stage.name == "length_trainability" ) assert trainability_stage.passed is True assert trainability_stage.metrics == { - "latest_step": 3, - "final_eval_reward": 0.97, + "latest_step": 4, + "best_train_abs_error": 1.0, } - assert trainability_stage.artifact_dir == "/tmp/trainability" + assert 
trainability_stage.artifact_dir == "/tmp/length-trainability" + assert all(stage.name != "yes_no_trainability" for stage in report.stages) native_vllm_lora_stage = next( stage for stage in report.stages if stage.name == "native_vllm_lora" ) @@ -667,20 +692,63 @@ def test_run_yes_no_trainability_stage(monkeypatch) -> None: assert result.artifact_dir == "/tmp/trainability" +def test_run_length_trainability_stage(monkeypatch) -> None: + report = SimpleNamespace( + summary_log_path="/tmp/length-trainability/length_trainability.log", + model_dump=lambda mode="json": { + "latest_step": 3, + "initial_train_abs_error": 12.0, + "best_train_abs_error": 1.0, + }, + ) + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + lambda name: SimpleNamespace( + run_length_trainability=lambda *, base_model, allow_unvalidated_arch=False: ( + report + ), + length_trainability_passed=lambda candidate: candidate is report, + ), + ) + + result = run_length_trainability_stage( + base_model="Qwen/Qwen3.5-35B-A3B", + architecture=ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + ), + ) + + assert result.name == "length_trainability" + assert result.passed is True + assert result.artifact_dir == "/tmp/length-trainability" + + def test_run_train_inf_mismatch_stage(monkeypatch) -> None: + seen: dict[str, object] = {} + + def _run_train_inf_mismatch( + *, + base_model: str, + allow_unvalidated_arch: bool, + ) -> SimpleNamespace: + seen["allow_unvalidated_arch"] = allow_unvalidated_arch + return SimpleNamespace( + passed=True, + artifact_dir="/tmp/train-inf-mismatch", + model_dump=lambda mode="json": { + "base_model": base_model, + "passed": True, + "passed_count": 1, + "failed_count": 0, + }, + ) + monkeypatch.setattr( "tests.integration.megatron.model_support.workflow._import_integration_module", lambda name: SimpleNamespace( - run_train_inf_mismatch=lambda *, base_model: 
SimpleNamespace( - passed=True, - artifact_dir="/tmp/train-inf-mismatch", - model_dump=lambda mode="json": { - "base_model": base_model, - "passed": True, - "passed_count": 1, - "failed_count": 0, - }, - ) + run_train_inf_mismatch=_run_train_inf_mismatch, ), ) @@ -691,11 +759,13 @@ def test_run_train_inf_mismatch_stage(monkeypatch) -> None: model_key="qwen3_5_moe", handler_key="qwen3_5_moe", ), + allow_unvalidated_arch=True, ) assert result.name == "train_inf_mismatch" assert result.passed is True assert result.artifact_dir == "/tmp/train-inf-mismatch" + assert seen == {"allow_unvalidated_arch": True} assert result.metrics == { "base_model": "Qwen/Qwen3.5-35B-A3B", "passed": True, diff --git a/tests/integration/megatron/model_support/workflow.py b/tests/integration/megatron/model_support/workflow.py index 88f389b21..d987d7fb3 100644 --- a/tests/integration/megatron/model_support/workflow.py +++ b/tests/integration/megatron/model_support/workflow.py @@ -49,9 +49,10 @@ "correctness_sensitivity", "chat_template_rollout", "packed_position_ids", - "yes_no_trainability", + "length_trainability", ) NATIVE_VLLM_LORA_STAGE = "native_vllm_lora" +YES_NO_TRAINABILITY_STAGE = "yes_no_trainability" ARCHITECTURE_REPRESENTATIVE_MODELS = { "qwen3_moe": "Qwen/Qwen3-30B-A3B", "qwen3_dense": "Qwen/Qwen3-32B", @@ -67,7 +68,8 @@ "correctness_sensitivity", "chat_template_rollout", "packed_position_ids", - "yes_no_trainability", + "length_trainability", + YES_NO_TRAINABILITY_STAGE, NATIVE_VLLM_LORA_STAGE, } ) @@ -81,9 +83,12 @@ class AllArchitecturesValidationReport(BaseModel): def build_validation_stage_names( *, include_native_vllm_lora: bool = False, + include_yes_no_trainability: bool = False, native_vllm_lora_status: NativeVllmLoraStatus | None = None, ) -> list[str]: stages = list(MANDATORY_VALIDATION_STAGES) + if include_yes_no_trainability: + stages.append(YES_NO_TRAINABILITY_STAGE) if include_native_vllm_lora or native_vllm_lora_status not in {None, "disabled"}: 
stages.append(NATIVE_VLLM_LORA_STAGE) return stages @@ -103,6 +108,7 @@ def initialize_validation_report( *, base_model: str, include_native_vllm_lora: bool = False, + include_yes_no_trainability: bool = False, allow_unvalidated_arch: bool = False, ) -> ValidationReport: spec = get_model_support_spec( @@ -119,6 +125,7 @@ def initialize_validation_report( ValidationStageResult(name=stage_name) for stage_name in build_validation_stage_names( include_native_vllm_lora=include_native_vllm_lora, + include_yes_no_trainability=include_yes_no_trainability, native_vllm_lora_status=handler.native_vllm_lora_status, ) ], @@ -417,11 +424,13 @@ def run_train_inf_mismatch_stage( allow_unvalidated_arch: bool = False, ) -> ValidationStageResult: del architecture - del allow_unvalidated_arch train_inf_mismatch = _import_integration_module( "integration.megatron.train_inf_mismatch.workflow_stage" ) - report = train_inf_mismatch.run_train_inf_mismatch(base_model=base_model) + report = train_inf_mismatch.run_train_inf_mismatch( + base_model=base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) return ValidationStageResult( name="train_inf_mismatch", passed=report.passed, @@ -666,13 +675,35 @@ def run_yes_no_trainability_stage( and report.final_eval_reward > report.initial_eval_reward ) return ValidationStageResult( - name="yes_no_trainability", + name=YES_NO_TRAINABILITY_STAGE, passed=passed, metrics=report.model_dump(mode="json"), artifact_dir=report.output_dir, ) +def run_length_trainability_stage( + *, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool = False, +) -> ValidationStageResult: + del architecture + length_trainability = _import_integration_module( + "integration.megatron.trainability.test_live_length_trainability" + ) + report = length_trainability.run_length_trainability( + base_model=base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + return ValidationStageResult( + name="length_trainability", + 
passed=length_trainability.length_trainability_passed(report), + metrics=report.model_dump(mode="json"), + artifact_dir=str(Path(report.summary_log_path).parent), + ) + + def run_native_vllm_lora_stage( *, base_model: str, @@ -747,6 +778,7 @@ def build_validation_report( *, base_model: str, include_native_vllm_lora: bool = False, + include_yes_no_trainability: bool = False, include_sensitivity: bool | None = None, output_json: str | Path | None = None, skip_stages: set[str] | None = None, @@ -756,6 +788,7 @@ def build_validation_report( report = initialize_validation_report( base_model=base_model, include_native_vllm_lora=include_native_vllm_lora, + include_yes_no_trainability=include_yes_no_trainability, allow_unvalidated_arch=allow_unvalidated_arch, ) stage_runners = { @@ -766,7 +799,8 @@ def build_validation_report( "correctness_sensitivity": run_correctness_sensitivity_stage, "chat_template_rollout": run_chat_template_rollout_stage, "packed_position_ids": run_packed_position_ids_stage, - "yes_no_trainability": run_yes_no_trainability_stage, + "length_trainability": run_length_trainability_stage, + YES_NO_TRAINABILITY_STAGE: run_yes_no_trainability_stage, NATIVE_VLLM_LORA_STAGE: run_native_vllm_lora_stage, } env = ( @@ -851,6 +885,7 @@ def build_validation_report( def build_all_architectures_validation_report( *, + include_yes_no_trainability: bool = False, include_sensitivity: bool | None = None, output_json: str | Path | None = None, skip_stages: set[str] | None = None, @@ -866,6 +901,7 @@ def build_all_architectures_validation_report( ).key report = build_validation_report( base_model=base_model, + include_yes_no_trainability=include_yes_no_trainability, include_sensitivity=include_sensitivity, output_json=( _per_architecture_output_json(output_json, model_key) @@ -897,6 +933,7 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser.add_argument("--output-json", required=True) parser.add_argument("--allow-unsupported-arch", 
action="store_true") parser.add_argument("--include-sensitivity", action="store_true") + parser.add_argument("--include-yes-no-trainability", action="store_true") parser.add_argument("--skip-stage", action="append", default=[]) parser.add_argument("--stop-on-failure", action="store_true") return parser.parse_args(argv) @@ -906,6 +943,7 @@ def main(argv: list[str] | None = None) -> int: args = _parse_args(argv) if args.all_architectures: all_report = build_all_architectures_validation_report( + include_yes_no_trainability=args.include_yes_no_trainability, include_sensitivity=args.include_sensitivity, output_json=args.output_json, skip_stages=set(args.skip_stage), @@ -925,6 +963,7 @@ def main(argv: list[str] | None = None) -> int: return 0 if all_report.passed else 1 report = build_validation_report( base_model=args.base_model, + include_yes_no_trainability=args.include_yes_no_trainability, include_sensitivity=args.include_sensitivity, output_json=args.output_json, skip_stages=set(args.skip_stage), diff --git a/tests/integration/megatron/model_support/workflow_stage_worker.py b/tests/integration/megatron/model_support/workflow_stage_worker.py index c854259fa..f12456d24 100644 --- a/tests/integration/megatron/model_support/workflow_stage_worker.py +++ b/tests/integration/megatron/model_support/workflow_stage_worker.py @@ -7,6 +7,7 @@ run_chat_template_rollout_stage, run_correctness_sensitivity_stage, run_hf_parity_stage, + run_length_trainability_stage, run_lora_coverage_stage, run_merged_vllm_serving_stage, run_native_vllm_lora_stage, @@ -23,6 +24,7 @@ "correctness_sensitivity": run_correctness_sensitivity_stage, "chat_template_rollout": run_chat_template_rollout_stage, "packed_position_ids": run_packed_position_ids_stage, + "length_trainability": run_length_trainability_stage, "yes_no_trainability": run_yes_no_trainability_stage, "native_vllm_lora": run_native_vllm_lora_stage, } diff --git 
a/tests/integration/megatron/runtime_isolation/test_live_megatron_backend_smoke.py b/tests/integration/megatron/runtime_isolation/test_live_megatron_backend_smoke.py index 7cc102473..e81858fcc 100644 --- a/tests/integration/megatron/runtime_isolation/test_live_megatron_backend_smoke.py +++ b/tests/integration/megatron/runtime_isolation/test_live_megatron_backend_smoke.py @@ -198,6 +198,17 @@ async def _megatron_backend_context( ) -> AsyncIterator[MegatronBackend]: with _wandb_disabled(): with provider_topology_env(topology): + art.init_megatron_runtime_config( + topology=art.MegatronTopologyConfig( + tp=topology.tp, + cp=topology.cp, + ep=topology.ep, + pp=topology.pp, + vpp=topology.vpp if topology.vpp != 1 else None, + etp=topology.etp, + ), + packed_sequence_length=_packed_sequence_length(), + ) async with MegatronBackend( path=str(backend_root), in_process=False ) as backend: @@ -286,7 +297,6 @@ async def test_megatron_backend_shared_lora_runtime_sleep_wake_live_smoke( train_groups, learning_rate=float(os.environ.get("ART_TEST_MEGATRON_LR", "1e-4")), loss_fn="cispo", - packed_sequence_length=_packed_sequence_length(), ) ) observed_sleep = False @@ -379,7 +389,6 @@ async def test_megatron_backend_dedicated_merged_live_smoke( train_groups, learning_rate=float(os.environ.get("ART_TEST_MEGATRON_LR", "1e-4")), loss_fn="cispo", - packed_sequence_length=_packed_sequence_length(), ) latest_step = int(result.step) latest_name = model.get_inference_name(step=latest_step) @@ -453,7 +462,6 @@ async def test_megatron_backend_dedicated_multirank_merged_live_smoke( train_groups, learning_rate=float(os.environ.get("ART_TEST_MEGATRON_LR", "1e-4")), loss_fn="cispo", - packed_sequence_length=_packed_sequence_length(), ) latest_step = int(result.step) latest_name = model.get_inference_name(step=latest_step) @@ -535,7 +543,6 @@ async def test_megatron_backend_shared_lora_ten_step_live_smoke( train_groups, learning_rate=float(os.environ.get("ART_TEST_MEGATRON_LR", "1e-4")), 
loss_fn="cispo", - packed_sequence_length=_packed_sequence_length(), ) ) observed_sleep = False diff --git a/tests/integration/megatron/runtime_isolation/test_runtime_project_isolation.py b/tests/integration/megatron/runtime_isolation/test_runtime_project_isolation.py index 10d9edc3c..959d72e92 100644 --- a/tests/integration/megatron/runtime_isolation/test_runtime_project_isolation.py +++ b/tests/integration/megatron/runtime_isolation/test_runtime_project_isolation.py @@ -87,6 +87,89 @@ def test_runtime_general_plugin_loads_full_patch_set() -> None: assert 'art = "art_vllm_runtime.patches:apply_vllm_runtime_patches"' in pyproject +def test_runtime_patch_adds_gemma4_moe_topk_alias(artifact_dir: Path) -> None: + result = subprocess.run( + [ + "uv", + "run", + "--project", + str(ROOT / "vllm_runtime"), + "python", + "-c", + ( + "import json; " + "from art_vllm_runtime.patches import apply_vllm_runtime_patches; " + "apply_vllm_runtime_patches(); " + "from transformers import Gemma4TextConfig; " + "config = Gemma4TextConfig(enable_moe_block=True, top_k_experts=8); " + "print(json.dumps({'num_experts_per_tok': config.num_experts_per_tok}))" + ), + ], + cwd=ROOT, + check=True, + capture_output=True, + text=True, + ) + (artifact_dir / "gemma4_topk_alias_stdout.txt").write_text(result.stdout) + (artifact_dir / "gemma4_topk_alias_stderr.txt").write_text(result.stderr) + assert json.loads(result.stdout.strip()) == {"num_experts_per_tok": 8} + + +def test_runtime_patch_skips_gemma4_layerwise_weight_update_reload( + artifact_dir: Path, +) -> None: + result = subprocess.run( + [ + "uv", + "run", + "--project", + str(ROOT / "vllm_runtime"), + "python", + "-c", + ( + "import json; " + "from art_vllm_runtime.patches import apply_vllm_runtime_patches; " + "apply_vllm_runtime_patches(); " + "from vllm.v1.worker.gpu_worker import Worker; " + "HfConfig = type('HfConfig', (), {" + "'architectures': ['Gemma4ForConditionalGeneration']" + "}); " + "ModelConfig = type('ModelConfig', (), 
{'hf_config': HfConfig()}); " + "DummyWorker = type('DummyWorker', (), {" + "'model_config': ModelConfig(), " + "'_weight_update_active': False, " + "'_is_checkpoint_format': True, " + "'checks': 0, " + "'_check_weight_transfer_engine': " + "lambda self: setattr(self, 'checks', self.checks + 1)" + "}); " + "dummy = DummyWorker(); " + "Worker.start_weight_update(dummy, is_checkpoint_format=True); " + "active_after_start = dummy._weight_update_active; " + "Worker.finish_weight_update(dummy); " + "print(json.dumps({" + "'active_after_start': active_after_start, " + "'active_after_finish': dummy._weight_update_active, " + "'is_checkpoint_format': dummy._is_checkpoint_format, " + "'checks': dummy.checks" + "}))" + ), + ], + cwd=ROOT, + check=True, + capture_output=True, + text=True, + ) + (artifact_dir / "gemma4_weight_update_reload_stdout.txt").write_text(result.stdout) + (artifact_dir / "gemma4_weight_update_reload_stderr.txt").write_text(result.stderr) + assert json.loads(result.stdout.strip()) == { + "active_after_start": True, + "active_after_finish": False, + "is_checkpoint_format": True, + "checks": 2, + } + + def test_runtime_patch_set_does_not_install_lora_monkey_patches() -> None: source = ( ROOT / "vllm_runtime" / "src" / "art_vllm_runtime" / "patches.py" diff --git a/tests/integration/megatron/runtime_isolation/test_service_runtime_boundary.py b/tests/integration/megatron/runtime_isolation/test_service_runtime_boundary.py index 43242e256..36d0434f6 100644 --- a/tests/integration/megatron/runtime_isolation/test_service_runtime_boundary.py +++ b/tests/integration/megatron/runtime_isolation/test_service_runtime_boundary.py @@ -9,9 +9,16 @@ import httpx import pytest +import art from art.megatron.service import MegatronService -from art.types import MegatronTopologyConfig -from art.unsloth.service import UnslothService + + +@pytest.fixture(autouse=True) +def _init_megatron_runtime_config() -> None: + art.init_megatron_runtime_config( + 
topology=art.MegatronTopologyConfig(tp=1, cp=2, ep=2, etp=1), + packed_sequence_length=1024, + ) class _AsyncOkResponse: @@ -100,7 +107,8 @@ async def test_unsloth_shared_start_requires_runtime_sleep_mode( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - service = UnslothService( + unsloth_service = pytest.importorskip("art.unsloth.service") + service = unsloth_service.UnslothService( model_name="test-model", base_model="Qwen/Qwen3-0.6B", config={ @@ -156,7 +164,8 @@ async def test_unsloth_runtime_sleep_and_wake_use_runtime_routes( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - service = UnslothService( + unsloth_service = pytest.importorskip("art.unsloth.service") + service = unsloth_service.UnslothService( model_name="test-model", base_model="Qwen/Qwen3-0.6B", config={"rollout_weights_mode": "lora"}, @@ -204,7 +213,6 @@ async def test_megatron_dedicated_merged_start_syncs_initial_weights( sync_merged.assert_awaited_once_with( lora_path="/tmp/lora", step=0, - megatron_topology=None, ) @@ -220,7 +228,6 @@ async def test_megatron_dedicated_merged_start_uses_configured_topology( "trainer_gpu_ids": [0], "inference_gpu_ids": [1], "rollout_weights_mode": "merged", - "megatron_topology": {"tp": 1, "cp": 2, "ep": 2, "etp": 1}, }, output_dir=str(tmp_path), ) @@ -235,8 +242,8 @@ async def test_megatron_dedicated_merged_start_uses_configured_topology( sync_merged.assert_awaited_once_with( lora_path="/tmp/lora", step=0, - megatron_topology=MegatronTopologyConfig(tp=1, cp=2, ep=2, etp=1), ) + assert service.runtime_config.topology.cp == 2 @pytest.mark.asyncio diff --git a/tests/integration/megatron/train_inf_mismatch/output_parity.py b/tests/integration/megatron/train_inf_mismatch/output_parity.py index 5a2125c85..d128865ec 100644 --- a/tests/integration/megatron/train_inf_mismatch/output_parity.py +++ b/tests/integration/megatron/train_inf_mismatch/output_parity.py @@ -24,10 +24,16 @@ # 4.606% mean_abs_pct while staying under the KL gate, so its 
gate is 5%. BF16_FWD_MEAN_ABS_PCT_LIMIT = 4.0 BF16_FWD_MEAN_ABS_PCT_LIMIT_BY_MODEL_KEY = { - "qwen3_moe": 8.0, + # Gemma 4 MoE long-prompt SWA native-LoRA runs showed high variation, with + # repeated samples reaching 7.6% mean_abs_pct and 0.0076 KL. + "gemma4_moe": 8.0, + "qwen3_moe": 7.0, "qwen3_5_moe": 5.0, } TOP20_KL_CANDIDATE_TO_TARGET_LIMIT = 0.002 +TOP20_KL_CANDIDATE_TO_TARGET_LIMIT_BY_MODEL_KEY = { + "gemma4_moe": 0.008, +} MEAN_ABS_PCT_DENOMINATOR_EPS = 1e-18 TOP_K = 20 ScoreRecord = tuple[int, float, list[int], list[float]] @@ -96,7 +102,6 @@ class TrainInfOutputParityConfig(BaseModel): lora_target_modules: list[str] | None = None engine_args: dict[str, Any] = Field(default_factory=dict) server_args: dict[str, Any] = Field(default_factory=dict) - replay_vllm_routing: bool = False @model_validator(mode="after") def _set_default_rollout_modes(self) -> "TrainInfOutputParityConfig": @@ -267,11 +272,14 @@ def top20_kl_candidate_to_target_limit_for_model( ) -> float: from art.megatron.model_support.registry import get_model_support_spec - get_model_support_spec( + spec = get_model_support_spec( base_model, allow_unvalidated_arch=allow_unvalidated_arch, ) - return TOP20_KL_CANDIDATE_TO_TARGET_LIMIT + return TOP20_KL_CANDIDATE_TO_TARGET_LIMIT_BY_MODEL_KEY.get( + spec.key, + TOP20_KL_CANDIDATE_TO_TARGET_LIMIT, + ) def model_support_is_moe( @@ -865,6 +873,11 @@ def _run_logits( attention_state = create_shared_prefix_state( group_ids=group_ids, parent_ids=parent_ids, + input_pos=position_ids, + sliding_windows=tuple( + int(window) + for window in getattr(runtime.provider, "art_flex_sliding_windows", ()) + ), build_gdn_execution_spec=bool( getattr(runtime.model_support_handler, "build_gdn_execution_spec", False) ), @@ -1023,6 +1036,7 @@ def _score_context_parallel_once( import torch import torch.distributed as dist + dist_any = cast(Any, dist) from art.megatron.context_parallel.types import ParallelTopology from art.megatron.training.microbatches import 
_prepare_current_rl_micro from art.megatron.training.trace import ( @@ -1090,9 +1104,9 @@ def _score_context_parallel_once( desired_uids=set(logical_uids), ) gathered_records: list[dict[int, ScoreRecord]] = [ - {} for _ in range(dist.get_world_size()) + {} for _ in range(dist_any.get_world_size()) ] - dist.all_gather_object(gathered_records, local_records) + dist_any.all_gather_object(gathered_records, local_records) return _score_bundle_from_records( records=_merge_score_records(gathered_records), logical_tokens=logical_tokens, diff --git a/tests/integration/megatron/train_inf_mismatch/real_path.py b/tests/integration/megatron/train_inf_mismatch/real_path.py index 07bce52f8..5cde6ced6 100644 --- a/tests/integration/megatron/train_inf_mismatch/real_path.py +++ b/tests/integration/megatron/train_inf_mismatch/real_path.py @@ -16,7 +16,12 @@ from openai.types.chat.chat_completion import Choice from pydantic import BaseModel, ConfigDict, Field -from art.preprocessing.moe_routing import choice_moe_routing_metadata +from art.dev.model import RolloutWeightsMode +from art.preprocessing.moe_routing import ( + MoeRoutingPackStats, + PackedMoeRoutingReplay, + choice_moe_routing_metadata, +) from art.preprocessing.pack import DiskPackedTensors from .artifacts import REPO_ROOT @@ -24,6 +29,7 @@ TOP_K, LogicalTokenMap, PairComparison, + RolloutMode, ScoreBundle, TokenTopK, TopKComparison, @@ -60,6 +66,8 @@ class RealPathConfig(BaseModel): rollouts_per_prompt: int = 2 max_completion_tokens: int = 16 prompt_sentence_count: int = 28 + prompt_target_tokens: int | None = None + sliding_window: int | None = None diagnose_base: bool = False trace_layers: bool = False trace_enforce_eager: bool = False @@ -127,6 +135,16 @@ class RealPathTrainInfReport(BaseModel): passed: bool +def _real_path_rollout_mode(config: TrainInfOutputParityConfig) -> RolloutMode: + return config.rollout_modes[0] + + +def _real_path_rollout_weights_mode( + config: TrainInfOutputParityConfig, +) -> 
RolloutWeightsMode: + return "lora" if _real_path_rollout_mode(config) == "native_lora" else "merged" + + _PROMPT_SENTENCES = [ "A careful systems engineer checks assumptions before changing thresholds.", "The training batch contains shared prefixes and divergent completions.", @@ -153,6 +171,7 @@ class RealPathTrainInfReport(BaseModel): "The run should not update weights just to measure a forward mismatch.", "Validation code belongs in tests unless production needs the behavior.", ] +_PROMPT_TOKENS_PER_SENTENCE_ESTIMATE = 12 def config_from_env() -> RealPathConfig: @@ -167,6 +186,8 @@ def config_from_env() -> RealPathConfig: config.max_completion_tokens = int(raw) if raw := os.environ.get("ART_REAL_PATH_PROMPT_SENTENCE_COUNT"): config.prompt_sentence_count = int(raw) + if raw := os.environ.get("ART_REAL_PATH_PROMPT_TARGET_TOKENS"): + config.prompt_target_tokens = int(raw) if raw := os.environ.get("ART_REAL_PATH_DIAGNOSE_BASE"): config.diagnose_base = raw == "1" if raw := os.environ.get("ART_REAL_PATH_TRACE_LAYERS"): @@ -178,6 +199,59 @@ def config_from_env() -> RealPathConfig: return config +def _round_up(value: int, multiple: int) -> int: + return ((value + multiple - 1) // multiple) * multiple + + +def _config_sliding_window(config: TrainInfOutputParityConfig) -> int | None: + from huggingface_hub import hf_hub_download + + local_config_path = Path(config.base_model) / "config.json" + config_path = ( + local_config_path + if local_config_path.exists() + else Path(hf_hub_download(config.base_model, "config.json")) + ) + hf_config = _read_json(config_path) + text_config = hf_config.get("text_config", hf_config) + if not isinstance(text_config, dict): + return None + layer_types = tuple(str(value) for value in text_config.get("layer_types", ())) + if not any("sliding" in layer_type for layer_type in layer_types): + return None + window = text_config.get("sliding_window") + if window is None: + return None + return int(window) + + +def 
_apply_sliding_window_prompt_defaults(config: RealPathConfig) -> None: + window = _config_sliding_window(config.output_parity) + if window is None: + return + config.sliding_window = window + if config.prompt_target_tokens is None: + config.prompt_target_tokens = 2 * window + config.prompt_sentence_count = max( + config.prompt_sentence_count, + (config.prompt_target_tokens + _PROMPT_TOKENS_PER_SENTENCE_ESTIMATE - 1) + // _PROMPT_TOKENS_PER_SENTENCE_ESTIMATE, + ) + min_sequence_length = _round_up( + config.prompt_target_tokens + config.max_completion_tokens + 256, + 128, + ) + if config.output_parity.packed.sequence_length < min_sequence_length: + config.output_parity.packed.sequence_length = min_sequence_length + + +def _build_prompt_from_sentences(index: int, sentences: list[str]) -> str: + return ( + "Write a concise continuation for probe " + f"{index}. Preserve the technical tone.\n\n" + " ".join(sentences) + ) + + def _build_prompts(config: RealPathConfig) -> list[str]: rng = random.Random(config.output_parity.seed) prompts: list[str] = [] @@ -185,10 +259,7 @@ def _build_prompts(config: RealPathConfig) -> list[str]: sentences = [ rng.choice(_PROMPT_SENTENCES) for _ in range(config.prompt_sentence_count) ] - prompts.append( - "Write a concise continuation for probe " - f"{index}. 
Preserve the technical tone.\n\n" + " ".join(sentences) - ) + prompts.append(_build_prompt_from_sentences(index, sentences)) return prompts @@ -411,6 +482,7 @@ def _vllm_scores_from_real_choices( logical_map: LogicalTokenMap, require_routing_metadata: bool, weight_state: WeightState, + rollout_mode: RolloutMode, ) -> ScoreBundle: choices_by_tokens = _choice_score_index( trajectory_groups, @@ -470,7 +542,7 @@ def _vllm_scores_from_real_choices( return ScoreBundle( side="vllm", weight_state=weight_state, - rollout_mode="native_lora", + rollout_mode=rollout_mode, target_logprobs=target_logprobs, topk=topk, ) @@ -484,7 +556,6 @@ async def _score_base_real_generation_path( ) -> RealPathBaseDiagnosticBundle: import art from art.megatron.backend import MegatronBackend - from art.preprocessing.moe_routing import MoeRoutingPackStats from art.preprocessing.pack import packed_tensors_to_dir parity_config = config.output_parity @@ -503,7 +574,6 @@ async def _score_base_real_generation_path( engine_args.pop("lora_target_modules", None) if is_moe: engine_args["enable_return_routed_experts"] = True - engine_args["async_scheduling"] = False vllm_forward_trace_dir = ( artifact_dir / "real_path_base_vllm_forward_trace" if config.trace_layers @@ -570,6 +640,7 @@ async def _score_base_real_generation_path( logical_map=logical_map, require_routing_metadata=is_moe, weight_state="base", + rollout_mode="merged", ) vllm_score_path = artifact_dir / "real_path_vllm_base_scores.json" _write_json(vllm_score_path, vllm_base.model_dump(mode="json")) @@ -584,7 +655,10 @@ async def _score_base_real_generation_path( global_grad_accumulation_sequences=global_grad_accumulation_sequences, ).to_dir(routing_replay_dir) routing_replay_path = str(routing_replay_dir) - stats = packed_tensors["moe_routing_replay"].pack_stats + routing_replay = cast( + PackedMoeRoutingReplay, packed_tensors["moe_routing_replay"] + ) + stats = routing_replay.pack_stats else: stats = MoeRoutingPackStats() @@ -668,6 +742,21 @@ 
def _routing_topology_from_config(config: TrainInfOutputParityConfig) -> Any: ) +def _init_art_megatron_runtime_config(config: TrainInfOutputParityConfig) -> None: + import art + + art.init_megatron_runtime_config( + topology=art.MegatronTopologyConfig( + tp=config.topology.tp, + cp=config.topology.cp, + ep=config.topology.ep, + pp=config.topology.pp, + etp=config.topology.etp, + ), + packed_sequence_length=config.packed.sequence_length, + ) + + def _build_real_path_moe_routing_replay_bundle( *, packed_tensors: Any, @@ -754,6 +843,7 @@ def _score_megatron_runtime( packed_tensors: dict[str, Any], logical_map: LogicalTokenMap, weight_state: WeightState, + rollout_mode: RolloutMode, global_grad_accumulation_sequences: int, forward_trace_capture: Any | None, forward_trace_dir: str | None, @@ -774,7 +864,7 @@ def _score_megatron_runtime( packed_tensors=packed_tensors, logical_map=logical_map, weight_state=weight_state, - rollout_mode="native_lora", + rollout_mode=rollout_mode, global_grad_accumulation_sequences=global_grad_accumulation_sequences, ) @@ -796,7 +886,7 @@ def _score_megatron_runtime( logical_map=logical_map, side="megatron", weight_state=weight_state, - rollout_mode="native_lora", + rollout_mode=rollout_mode, ) @@ -913,6 +1003,7 @@ def _configure_worker_bundle(bundle: Any) -> None: packed_tensors=cast(dict[str, Any], packed_tensors), logical_map=logical_map, weight_state=request.weight_state, + rollout_mode=_real_path_rollout_mode(request.config), global_grad_accumulation_sequences=request.global_grad_accumulation_sequences, forward_trace_capture=forward_trace_capture, forward_trace_dir=request.forward_trace_dir, @@ -1020,6 +1111,8 @@ async def run_real_path_train_inf_mismatch( from art.preprocessing.pack import packed_tensors_to_dir parity_config = config.output_parity + _apply_sliding_window_prompt_defaults(config) + rollout_mode = _real_path_rollout_mode(parity_config) is_moe = model_support_is_moe( parity_config.base_model, 
allow_unvalidated_arch=parity_config.allow_unvalidated_arch, @@ -1031,6 +1124,7 @@ async def run_real_path_train_inf_mismatch( if not adapter_path: raise RuntimeError("Real-path adapter worker did not create an adapter") + _init_art_megatron_runtime_config(parity_config) backend = MegatronBackend( path=str(artifact_dir / "art_path"), enable_expert_replay=is_moe, @@ -1040,10 +1134,15 @@ async def run_real_path_train_inf_mismatch( name=f"train-inf-real-{uuid.uuid4().hex[:8]}", project="train_inf_mismatch", base_model=parity_config.base_model, + lora_config=( + {"target_modules": _lora_target_modules(parity_config)} + if parity_config.lora_target_modules is not None + else None + ), _internal_config={ "trainer_gpu_ids": parity_config.trainer_gpu_ids, "inference_gpu_ids": parity_config.inference_gpu_ids, - "rollout_weights_mode": "lora", + "rollout_weights_mode": _real_path_rollout_weights_mode(parity_config), "allow_unvalidated_arch": parity_config.allow_unvalidated_arch, "engine_args": { "tensor_parallel_size": len(parity_config.inference_gpu_ids), @@ -1103,11 +1202,11 @@ async def run_real_path_train_inf_mismatch( cast(dict[str, Any], disk_packed_tensors), ) if is_moe: - routing_replay = packed_tensors["moe_routing_replay"] + routing_replay = cast( + PackedMoeRoutingReplay, packed_tensors["moe_routing_replay"] + ) stats = routing_replay.pack_stats else: - from art.preprocessing.moe_routing import MoeRoutingPackStats - stats = MoeRoutingPackStats() vllm_lora = _vllm_scores_from_real_choices( @@ -1115,6 +1214,7 @@ async def run_real_path_train_inf_mismatch( logical_map=logical_map, require_routing_metadata=is_moe, weight_state="lora", + rollout_mode=rollout_mode, ) _write_json( artifact_dir / "real_path_vllm_lora_scores.json", diff --git a/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py b/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py index ee0c71828..d42c93871 100644 --- 
a/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py +++ b/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py @@ -21,8 +21,14 @@ compare_topk, config_from_env, fwd_mean_abs_pct_limit_for_model, + top20_kl_candidate_to_target_limit_for_model, +) +from .real_path import ( + RealPathConfig, + _delete_adapter_safetensors_on_pass, + _real_path_rollout_mode, + _real_path_rollout_weights_mode, ) -from .real_path import RealPathConfig, _delete_adapter_safetensors_on_pass def test_logical_map_flattens_shared_prefix_branches() -> None: @@ -131,6 +137,21 @@ def test_real_path_default_generates_16_tokens_per_rollout() -> None: assert RealPathConfig().max_completion_tokens == 16 +def test_real_path_rollout_mode_follows_config() -> None: + native_config = TrainInfOutputParityConfig( + base_model="Qwen/Qwen3.5-35B-A3B", + ) + merged_config = TrainInfOutputParityConfig( + base_model="unvalidated/native-disabled", + allow_unvalidated_arch=True, + ) + + assert _real_path_rollout_mode(native_config) == "native_lora" + assert _real_path_rollout_weights_mode(native_config) == "lora" + assert _real_path_rollout_mode(merged_config) == "merged" + assert _real_path_rollout_weights_mode(merged_config) == "merged" + + def test_real_path_deletes_only_adapter_safetensors_on_pass(tmp_path) -> None: run_dir = tmp_path / "run" active_lora = run_dir / "real_path_active_lora" @@ -160,6 +181,24 @@ def test_architecture_specific_real_path_limits() -> None: assert TOP20_KL_CANDIDATE_TO_TARGET_LIMIT == 0.002 +def test_gemma4_real_path_limits() -> None: + assert ( + fwd_mean_abs_pct_limit_for_model( + "google/gemma-4-26B-A4B-it", + allow_unvalidated_arch=True, + ) + == 8.0 + ) + assert ( + top20_kl_candidate_to_target_limit_for_model( + "google/gemma-4-26B-A4B-it", + allow_unvalidated_arch=True, + ) + == 0.008 + ) + assert TOP20_KL_CANDIDATE_TO_TARGET_LIMIT == 0.002 + + def test_compare_topk_reports_restricted_intersection_kl() -> None: target = 
ScoreBundle( side="megatron", @@ -239,8 +278,11 @@ def test_workflow_stage_enables_live_train_inf_mismatch( import subprocess captured_env = {} + real_run = workflow_stage.subprocess.run def fake_run(*args, **kwargs): + if "env" not in kwargs: + return real_run(*args, **kwargs) captured_env.update(kwargs["env"]) return subprocess.CompletedProcess( args=args, @@ -252,8 +294,12 @@ def fake_run(*args, **kwargs): monkeypatch.setattr(workflow_stage, "create_artifact_dir", lambda _nodeid: tmp_path) monkeypatch.setattr(workflow_stage.subprocess, "run", fake_run) - report = workflow_stage.run_train_inf_mismatch(base_model="Qwen/Qwen3.5-35B-A3B") + report = workflow_stage.run_train_inf_mismatch( + base_model="Qwen/Qwen3.5-35B-A3B", + allow_unvalidated_arch=True, + ) assert report.passed is True assert captured_env["ART_RUN_TRAIN_INF_MISMATCH_LIVE"] == "1" + assert captured_env["ART_TRAIN_INF_MISMATCH_ALLOW_UNVALIDATED_ARCH"] == "1" assert captured_env["ART_REAL_PATH_MAX_COMPLETION_TOKENS"] == "16" diff --git a/tests/integration/megatron/train_inf_mismatch/workflow_stage.py b/tests/integration/megatron/train_inf_mismatch/workflow_stage.py index 3977c3a94..ae0a7cef2 100644 --- a/tests/integration/megatron/train_inf_mismatch/workflow_stage.py +++ b/tests/integration/megatron/train_inf_mismatch/workflow_stage.py @@ -61,13 +61,20 @@ def _attempt_limit() -> int: return min(attempts, MAX_ATTEMPTS) -def run_train_inf_mismatch(*, base_model: str) -> TrainInfMismatchReport: +def run_train_inf_mismatch( + *, + base_model: str, + allow_unvalidated_arch: bool = False, +) -> TrainInfMismatchReport: artifact_dir = create_artifact_dir("workflow::train_inf_mismatch") max_attempts = _attempt_limit() env = os.environ.copy() env["BASE_MODEL"] = base_model env["ART_RUN_TRAIN_INF_MISMATCH_LIVE"] = "1" env["ART_TRAIN_INF_MISMATCH_BASE_MODEL"] = base_model + env["ART_TRAIN_INF_MISMATCH_ALLOW_UNVALIDATED_ARCH"] = ( + "1" if allow_unvalidated_arch else "0" + ) 
env["ART_REAL_PATH_MAX_COMPLETION_TOKENS"] = "16" existing_pythonpath = env.get("PYTHONPATH") tests_dir = str(REPO_ROOT / "tests") diff --git a/tests/integration/megatron/trainability/test_config.py b/tests/integration/megatron/trainability/test_config.py index 6004e9a9f..b265c2b3d 100644 --- a/tests/integration/megatron/trainability/test_config.py +++ b/tests/integration/megatron/trainability/test_config.py @@ -17,7 +17,6 @@ _variant_max_steps, _variant_packed_sequence_length, _variant_rollouts_per_prompt, - _variant_train_kwargs, ) @@ -106,7 +105,6 @@ def test_megatron_variants_keep_short_packed_sequence_default(monkeypatch) -> No ) assert _variant_packed_sequence_length(variant) == 1024 - assert _variant_train_kwargs(variant) == {"packed_sequence_length": 1024} config = _build_internal_config( variant, base_model="Qwen/Qwen3-30B-A3B-Instruct-2507" ) @@ -130,7 +128,6 @@ def test_unsloth_variant_uses_chunk_aligned_training_length(monkeypatch) -> None ) assert _variant_packed_sequence_length(variant) == 1024 - assert _variant_train_kwargs(variant) == {"packed_sequence_length": 1024} assert _variant_init_args(variant) == {"max_seq_length": 1024} assert _build_internal_config( variant, base_model="Qwen/Qwen3-30B-A3B-Instruct-2507" @@ -165,7 +162,7 @@ def test_validated_dense_model_uses_dense_shared_topology( base_model="Qwen/Qwen3.5-4B", ) assert built_variant.topology is not None - assert built_variant.topology.tp == 2 + assert built_variant.topology.cp == 2 assert built_variant.topology.ep == 1 assert built_variant.topology.etp == 1 diff --git a/tests/integration/megatron/trainability/test_live_length_trainability.py b/tests/integration/megatron/trainability/test_live_length_trainability.py new file mode 100644 index 000000000..32a8270bf --- /dev/null +++ b/tests/integration/megatron/trainability/test_live_length_trainability.py @@ -0,0 +1,714 @@ +from __future__ import annotations + +import asyncio +import json +import math +import os +from pathlib import Path 
+import random +import shutil +from typing import Any, AsyncIterator, Literal, cast +import uuid + +from pydantic import BaseModel, Field +import pytest + +import art +from art.megatron.model_support.registry import model_uses_expert_parallel +from art.pipeline_trainer import PipelineTrainer + +from ..model_support.oracle_harness import Topology +from .yes_no_trainability import ( + _backend_context, + _build_internal_config, + _build_variant, + _get_env_bool, + _get_env_float, + _get_env_int, + _init_megatron_runtime_config, + _list_model_ids, +) + +torch = pytest.importorskip("torch") + +DEFAULT_BASE_MODEL = "Qwen/Qwen3.5-35B-A3B" +LIVE_ENV = "ART_RUN_LIVE_LENGTH_TRAINABILITY" +TRAINER_GPU_IDS_ENV = "ART_MODEL_SUPPORT_TRAINER_GPU_IDS" +INFERENCE_GPU_IDS_ENV = "ART_MODEL_SUPPORT_INFERENCE_GPU_IDS" +REPO_ROOT = Path(__file__).resolve().parents[4] +LATEST_SUMMARY_LOG_PATH = REPO_ROOT / ".local" / "length_trainability.log" +INITIAL_ABS_ERROR_MIN = 5.0 +SUCCESS_ABS_ERROR_MAX = 1.5 +MOE_DEDICATED_TRAINING_TOPOLOGY = Topology( + tp=1, + cp=2, + ep=2, + etp=1, + dp=1, + sp=False, +) +BASE_PROMPT = ( + "Write a plain answer about a quiet harbor. Use the unrelated notes below " + "only as background texture. Use one sentence. Do not use bullets, numbering, " + "code, or a preface." 
+) +FILLER_SENTENCES = ( + "The morning ledger mentioned a bicycle bell near the old customs window.", + "A folded receipt waited beside three dull pencils and a chipped mug.", + "Someone had drawn a small square around Thursday on the calendar.", + "The storage room smelled faintly of rope, dust, and yesterday's rain.", + "A green notebook listed errands that no one seemed eager to finish.", + "The clock above the doorway ticked with a patient mechanical rhythm.", + "Two mismatched gloves rested under the bench near the umbrella stand.", + "A paper tag fluttered from a crate of spare brass hinges.", + "The shop radio murmured about traffic far from the waterfront.", + "A narrow envelope contained a map with several coffee stains.", + "The caretaker had stacked clean towels beside a basket of loose keys.", + "A faded poster advertised a lecture about practical knot repairs.", + "Someone left a blue scarf draped over the back of a wooden chair.", + "The rain gauge showed a modest line from a storm before dawn.", + "A quiet clerk sorted stamps into a tin marked for later use.", + "The window latch clicked softly whenever a colder breeze arrived.", + "A jar of buttons sat near the lamp with no label attached.", + "The floorboards held a faint shine where people usually turned left.", + "A postcard showed a bridge, though no bridge could be seen nearby.", + "The supply shelf included chalk, twine, soap, and several blank cards.", + "A small toolbox waited open with every socket arranged by size.", + "The notice board carried old schedules with careful handwritten corrections.", + "A kettle cooled on the counter beside a plate of plain biscuits.", + "The narrow hallway displayed framed photographs of ordinary cloudy afternoons.", + "A stack of forms leaned against a vase holding one dry reed.", + "The back office kept a spare lantern wrapped in brown paper.", + "A silver whistle hung from a nail beside the maintenance checklist.", + "The cupboard door closed unevenly 
unless pressed near the lower hinge.", + "A receipt book recorded purchases of candles, nails, and black ink.", + "The stair rail felt smooth where many hands had passed over it.", + "A shallow drawer contained string, labels, and a forgotten measuring tape.", + "The wall map used faded pins to mark unimportant delivery stops.", + "A wool cap lay on a crate beside a coil of clean line.", + "The afternoon light made the dust above the desk look almost orderly.", + "A clipboard noted that the north window should be painted soon.", + "The brass hook near the door held only an empty canvas bag.", + "A stack of newspapers waited under a stone used as a weight.", + "The broom leaned in a corner beside a cardboard box of washers.", + "A shallow bowl held wrapped peppermints for visitors who rarely arrived.", + "The gray filing cabinet opened with a scrape and a small sigh.", + "A pencil sharpener was screwed to the wall beside a crooked shelf.", + "The old ledger contained careful columns and very little useful drama.", + "A canvas cover protected the spare chair from dust and sunlight.", + "The side table held a ruler, a thimble, and a sealed jar.", + "A neat row of jars preserved screws sorted by uncertain categories.", + "The calendar showed local holidays in red and market days in blue.", + "A small bell above the entrance moved only when the door stuck.", + "The envelope tray was empty except for a note about lamp oil.", + "The desk drawer included a spare button and two brittle rubber bands.", + "A plain brown box carried the words archive later in pencil.", +) + + +class LengthScenario(BaseModel): + scenario_index: int + target_step: int + target_tokens: int + max_tokens: int + prompt: str + prompt_word_count: int + metadata: dict[str, int | float | str | None] = Field(default_factory=dict) + + +class LengthSampleReport(BaseModel): + split: Literal["train"] + step: int | None + scenario_index: int + target_step: int + target_tokens: int + max_tokens: int + 
prompt_word_count: int + generated_tokens: int + abs_error: int + reward: float + text: str + + +class LengthTrainabilityReport(BaseModel): + base_model: str + max_steps: int + max_steps_off_policy: int + latest_step: int + variant_name: str + trainer_gpu_ids: list[int] + inference_gpu_ids: list[int] + training_topology: dict[str, int | bool] + rollout_weights_mode: str + rollouts_per_prompt: int + normalize_advantages: bool + summary_log_path: str + latest_summary_log_path: str + initial_train_abs_error: float | None + best_train_abs_error: float | None + success_step: int | None + final_train_reward: float | None + final_train_abs_error: float | None + model_ids_after: list[str] + samples: list[LengthSampleReport] + + +def _require_opt_in() -> None: + if os.environ.get(LIVE_ENV) != "1": + pytest.skip(f"set {LIVE_ENV}=1 to run live length trainability") + + +def _base_model() -> str: + return os.environ.get( + "ART_LIVE_LENGTH_BASE_MODEL", + os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL), + ) + + +def _slugify(value: str) -> str: + return value.lower().replace("/", "_").replace(".", "_").replace("-", "_") + + +def _artifact_dir(base_model: str) -> Path: + path = ( + REPO_ROOT + / ".local" + / "model_support_validation" + / _slugify(base_model) + / "length_trainability" + ) + path.mkdir(parents=True, exist_ok=True) + return path + + +def _word_count(text: str) -> int: + return len(text.split()) + + +def _target_tokens() -> int: + return _get_env_int("ART_MODEL_SUPPORT_LENGTH_TARGET_TOKENS", 10) + + +def _use_default_moe_dedicated_placement(variant: Any, *, base_model: str) -> None: + if not model_uses_expert_parallel(base_model, allow_unvalidated_arch=True): + return + if os.environ.get(TRAINER_GPU_IDS_ENV) or os.environ.get(INFERENCE_GPU_IDS_ENV): + return + if torch.cuda.device_count() < 3: + pytest.skip( + "Need at least 3 visible CUDA GPUs for default dedicated MoE length " + "trainability: 2 trainer GPUs and 1 inference GPU." 
+ ) + variant.trainer_gpu_ids = [0, 1] + variant.inference_gpu_ids = [2] + variant.topology = MOE_DEDICATED_TRAINING_TOPOLOGY + + +def _check_prompt_hides_target(prompt: str) -> None: + lowered = prompt.lower() + leaked = [ + phrase + for phrase in ("generated tokens", "target tokens", "target length", "exactly") + if phrase in lowered + ] + if leaked: + raise RuntimeError(f"Length prompt leaks target wording: {leaked}") + + +def _prompt_for_index(index: int) -> tuple[str, int]: + target_words = _get_env_int("ART_MODEL_SUPPORT_LENGTH_PROMPT_WORDS", 300) + rng = random.Random(index) + sentences = list(FILLER_SENTENCES) + rng.shuffle(sentences) + selected: list[str] = [] + prompt = BASE_PROMPT + for sentence in sentences: + if _word_count(prompt) >= target_words: + break + selected.append(sentence) + prompt = f"{BASE_PROMPT}\n\nNotes: {' '.join(selected)}" + _check_prompt_hides_target(prompt) + return prompt, _word_count(prompt) + + +def _scenario(index: int, *, target_step: int | None = None) -> LengthScenario: + target_tokens = _target_tokens() + max_tokens = max( + target_tokens + 1, + math.ceil( + target_tokens + * _get_env_float("ART_MODEL_SUPPORT_LENGTH_MAX_TOKENS_MULTIPLIER", 1.4) + ) + + 128, + ) + prompt, prompt_word_count = _prompt_for_index(index) + return LengthScenario( + scenario_index=index, + target_step=index if target_step is None else target_step, + target_tokens=target_tokens, + max_tokens=max_tokens, + prompt=prompt, + prompt_word_count=prompt_word_count, + metadata={ + "scenario_index": index, + "target_step": index if target_step is None else target_step, + "target_tokens": target_tokens, + "max_tokens": max_tokens, + "prompt_word_count": prompt_word_count, + }, + ) + + +def _step_from_model_name(model_name: str) -> int | None: + if "@" not in model_name: + return None + try: + return int(model_name.rsplit("@", 1)[1]) + except ValueError: + return None + + +def _scenario_for_training_step( + scenario: LengthScenario | dict[str, object], + step: 
int, +) -> LengthScenario: + parsed = LengthScenario.model_validate(scenario) + return _scenario(parsed.scenario_index, target_step=step) + + +def _messages(scenario: LengthScenario) -> art.Messages: + return [{"role": "user", "content": scenario.prompt}] + + +def _extra_body() -> dict[str, object]: + return {"chat_template_kwargs": {"enable_thinking": False}} + + +def _generated_token_count(choice: object) -> int: + logprobs = getattr(choice, "logprobs", None) + content = getattr(logprobs, "content", None) + if content is not None: + return len(content) + message = getattr(choice, "message", None) + return len((getattr(message, "content", "") or "").split()) + + +def _reward(generated_tokens: int, target_tokens: int) -> float: + # Do not clamp: early generations can be far from target, and CISPO still + # needs within-group reward differences to produce trainable advantages. + return -abs(generated_tokens - target_tokens) / max(1, target_tokens) + + +def _sample_report( + *, + split: Literal["train"], + step: int | None, + scenario: LengthScenario, + choice: object, +) -> LengthSampleReport: + generated_tokens = _generated_token_count(choice) + message = getattr(choice, "message", None) + text = getattr(message, "content", "") or "" + return LengthSampleReport( + split=split, + step=step, + scenario_index=scenario.scenario_index, + target_step=scenario.target_step, + target_tokens=scenario.target_tokens, + max_tokens=scenario.max_tokens, + prompt_word_count=scenario.prompt_word_count, + generated_tokens=generated_tokens, + abs_error=abs(generated_tokens - scenario.target_tokens), + reward=_reward(generated_tokens, scenario.target_tokens), + text=text, + ) + + +async def _length_group( + model: art.TrainableModel, + *, + scenario: LengthScenario, + model_name: str, + split: Literal["train"], + step: int | None, + n: int, + temperature: float, + samples: list[LengthSampleReport], + summary_log_path: Path | None = None, +) -> art.TrajectoryGroup: + messages = 
_messages(scenario) + completion = await model.openai_client().chat.completions.create( + messages=messages, + model=model_name, + max_tokens=scenario.max_tokens, + n=n, + temperature=temperature, + extra_body=_extra_body(), + logprobs=True, + top_logprobs=0, + timeout=_get_env_float("ART_MODEL_SUPPORT_LENGTH_REQUEST_TIMEOUT", 900.0), + ) + trajectories: list[art.Trajectory] = [] + for choice in completion.choices: + report = _sample_report( + split=split, + step=step, + scenario=scenario, + choice=choice, + ) + samples.append(report) + trajectories.append( + art.Trajectory( + messages_and_choices=[*messages, choice], + reward=report.reward, + metrics={ + "length/generated_tokens": report.generated_tokens, + "length/target_tokens": scenario.target_tokens, + "length/max_tokens": scenario.max_tokens, + "length/prompt_word_count": scenario.prompt_word_count, + "length/abs_error": report.abs_error, + }, + metadata=scenario.metadata, + ) + ) + _append_step_summary(summary_log_path, samples, split=split, step=step) + return art.TrajectoryGroup(trajectories) + + +def _mean_reward(samples: list[LengthSampleReport]) -> float: + return sum(sample.reward for sample in samples) / max(1, len(samples)) + + +def _mean(values: list[float]) -> float: + return sum(values) / max(1, len(values)) + + +def _mean_abs_error_by_step(samples: list[LengthSampleReport]) -> dict[int, float]: + steps = sorted({sample.step for sample in samples if sample.step is not None}) + return { + step: _mean( + [float(sample.abs_error) for sample in samples if sample.step == step] + ) + for step in steps + } + + +def _init_summary_log(path: Path) -> None: + path.write_text( + "\n".join( + ( + "# length trainability summary", + "# rows append when a rollout/eval group completes; n is cumulative for split+step", + ( + "split step target max_tok prompt_w n reward_mean " + "gen_mean abs_err_mean gen_min gen_max reward_min reward_max" + ), + ) + ) + + "\n", + encoding="utf-8", + ) + 
_copy_latest_summary_log(path) + + +def _copy_latest_summary_log(path: Path) -> None: + LATEST_SUMMARY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(path, LATEST_SUMMARY_LOG_PATH) + + +def _append_step_summary( + path: Path | None, + samples: list[LengthSampleReport], + *, + split: Literal["train"], + step: int | None, +) -> None: + if path is None: + return + matching = [ + sample for sample in samples if sample.split == split and sample.step == step + ] + if not matching: + return + generated = [float(sample.generated_tokens) for sample in matching] + abs_errors = [float(sample.abs_error) for sample in matching] + rewards = [sample.reward for sample in matching] + latest = matching[-1] + with path.open("a", encoding="utf-8") as handle: + handle.write( + f"{split:<9} {step if step is not None else '-':>4} " + f"{latest.target_tokens:>6} {latest.max_tokens:>7} " + f"{latest.prompt_word_count:>8} {len(matching):>5} " + f"{_mean(rewards):>11.4f} {_mean(generated):>8.1f} " + f"{_mean(abs_errors):>12.1f} {int(min(generated)):>7} " + f"{int(max(generated)):>7} {min(rewards):>10.4f} " + f"{max(rewards):>10.4f}\n" + ) + _copy_latest_summary_log(path) + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 3, + reason="Need at least 3 CUDA GPUs for live dedicated length trainability", +) +@pytest.mark.asyncio +async def test_megatron_dedicated_length_trainability_live(artifact_dir: Path) -> None: + _require_opt_in() + report = await run_length_trainability_async( + base_model=_base_model(), + artifact_dir=artifact_dir, + allow_unvalidated_arch=True, + ) + assert_length_trainability_passed(report) + + +async def run_length_trainability_async( + *, + base_model: str = DEFAULT_BASE_MODEL, + artifact_dir: Path | None = None, + allow_unvalidated_arch: bool = False, +) -> LengthTrainabilityReport: + artifact_dir = artifact_dir or _artifact_dir(base_model) + variant = _build_variant( + "megatron_dedicated", + 
base_model=base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + _use_default_moe_dedicated_placement(variant, base_model=base_model) + max_steps = _get_env_int("ART_MODEL_SUPPORT_LENGTH_MAX_STEPS", 10) + max_steps_off_policy = _get_env_int( + "ART_MODEL_SUPPORT_LENGTH_MAX_STEPS_OFF_POLICY", + 0, + ) + rollouts_per_prompt = _get_env_int( + "ART_MODEL_SUPPORT_LENGTH_ROLLOUTS_PER_PROMPT", + 4, + ) + normalize_advantages = _get_env_bool( + "ART_MODEL_SUPPORT_LENGTH_NORMALIZE_ADVANTAGES", + True, + ) + rollout_workers = _get_env_int( + "ART_MODEL_SUPPORT_LENGTH_ROLLOUT_WORKERS", + max(1, max_steps_off_policy + 1), + ) + scenario_count = _get_env_int( + "ART_MODEL_SUPPORT_LENGTH_SCENARIOS", + max_steps * max(rollouts_per_prompt, 2) + rollout_workers + 4, + ) + success_hit = False + samples: list[LengthSampleReport] = [] + backend_root = artifact_dir / "megatron_dedicated_workspace" + summary_log_path = artifact_dir / "length_trainability.log" + _init_summary_log(summary_log_path) + internal_config = _build_internal_config( + variant, + base_model=base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + internal_config["engine_args"]["max_model_len"] = _get_env_int( + "ART_MODEL_SUPPORT_LENGTH_MAX_MODEL_LEN", + 1024, + ) + internal_config["engine_args"]["max_num_seqs"] = _get_env_int( + "ART_MODEL_SUPPORT_LENGTH_MAX_NUM_SEQS", + 4, + ) + rollout_weights_mode = internal_config["rollout_weights_mode"] + _init_megatron_runtime_config(variant) + + async with _backend_context(variant, backend_root=backend_root) as backend: + model = art.TrainableModel( + name=f"length-{uuid.uuid4().hex[:8]}", + project="integration-tests", + base_model=base_model, + _internal_config=internal_config, + report_metrics=[], + ) + await model.register(backend) + + async def scenarios() -> AsyncIterator[dict[str, object]]: + for index in range(scenario_count): + if success_hit: + break + yield _scenario(index, target_step=0).model_dump() + + async def rollout_fn( + 
rollout_model: art.TrainableModel, + scenario: dict[str, object], + _config: None, + ) -> art.TrajectoryGroup: + nonlocal success_hit + model_name = rollout_model.get_inference_name() + target_step = _step_from_model_name(model_name) + if target_step is None: + target_step = await rollout_model.get_step() + group = await _length_group( + rollout_model, + scenario=_scenario_for_training_step(scenario, target_step), + model_name=model_name, + split="train", + step=target_step, + n=rollouts_per_prompt, + temperature=_get_env_float( + "ART_MODEL_SUPPORT_LENGTH_ROLLOUT_TEMPERATURE", + 1.1, + ), + samples=samples, + summary_log_path=summary_log_path, + ) + if ( + _mean_abs_error_by_step( + [sample for sample in samples if sample.split == "train"] + )[target_step] + <= SUCCESS_ABS_ERROR_MAX + ): + success_hit = True + return group + + trainer = PipelineTrainer( + model=model, + backend=backend, + rollout_fn=rollout_fn, + scenarios=scenarios(), + config=None, + num_rollout_workers=rollout_workers, + min_batch_size=1, + max_batch_size=1, + max_steps_off_policy=max_steps_off_policy, + learning_rate=_get_env_float( + "ART_MODEL_SUPPORT_LENGTH_LEARNING_RATE", + 1e-4, + ), + loss_fn="cispo", + normalize_advantages=normalize_advantages, + max_steps=max_steps, + eval_every_n_steps=0, + eval_at_start=False, + save_checkpoint=False, + total_scenarios=scenario_count, + log_interval_seconds=30.0, + discard_queue_multiplier=1000, + resume=False, + ) + await trainer.train(handle_signals=False) + + latest_step = await model.get_step() + model_ids_after = await _list_model_ids(model) + + train_samples = [sample for sample in samples if sample.split == "train"] + train_rewards_by_step = { + step: [sample.reward for sample in train_samples if sample.step == step] + for step in {sample.step for sample in train_samples} + } + train_abs_error_by_step = _mean_abs_error_by_step(train_samples) + initial_train_abs_error = train_abs_error_by_step.get(0) + best_train_abs_error = ( + 
min(train_abs_error_by_step.values()) if train_abs_error_by_step else None + ) + success_step = next( + ( + step + for step, abs_error in train_abs_error_by_step.items() + if abs_error <= SUCCESS_ABS_ERROR_MAX + ), + None, + ) + final_train_samples = [ + sample for sample in train_samples if sample.step == latest_step - 1 + ] + final_train_reward = ( + _mean_reward(final_train_samples) if final_train_samples else None + ) + final_train_abs_error = ( + _mean([float(sample.abs_error) for sample in final_train_samples]) + if final_train_samples + else None + ) + topology = cast(Topology, variant.topology) + report = LengthTrainabilityReport( + base_model=base_model, + max_steps=max_steps, + max_steps_off_policy=max_steps_off_policy, + latest_step=latest_step, + variant_name=variant.name, + trainer_gpu_ids=variant.trainer_gpu_ids, + inference_gpu_ids=variant.inference_gpu_ids, + training_topology=cast(dict[str, int | bool], topology.model_dump()), + rollout_weights_mode=rollout_weights_mode, + rollouts_per_prompt=rollouts_per_prompt, + normalize_advantages=normalize_advantages, + summary_log_path=str(summary_log_path), + latest_summary_log_path=str(LATEST_SUMMARY_LOG_PATH), + initial_train_abs_error=initial_train_abs_error, + best_train_abs_error=best_train_abs_error, + success_step=success_step, + final_train_reward=final_train_reward, + final_train_abs_error=final_train_abs_error, + model_ids_after=model_ids_after, + samples=samples, + ) + (artifact_dir / "length_trainability.json").write_text( + json.dumps(report.model_dump(mode="json"), indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + return report + + +def run_length_trainability( + *, + base_model: str = DEFAULT_BASE_MODEL, + allow_unvalidated_arch: bool = False, +) -> LengthTrainabilityReport: + return asyncio.run( + run_length_trainability_async( + base_model=base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + ) + + +def length_trainability_passed(report: LengthTrainabilityReport) -> 
bool: + train_samples = [sample for sample in report.samples if sample.split == "train"] + train_rewards_by_step = { + step: [sample.reward for sample in train_samples if sample.step == step] + for step in {sample.step for sample in train_samples} + } + return ( + bool(train_samples) + and report.latest_step <= report.max_steps + and report.initial_train_abs_error is not None + and report.initial_train_abs_error >= INITIAL_ABS_ERROR_MIN + and report.best_train_abs_error is not None + and report.best_train_abs_error <= SUCCESS_ABS_ERROR_MAX + and report.success_step is not None + and len(train_rewards_by_step) <= report.max_steps + and all(sample.max_tokens > sample.target_tokens for sample in train_samples) + and any(sample.generated_tokens < sample.max_tokens for sample in train_samples) + and any(len(set(rewards)) > 1 for rewards in train_rewards_by_step.values()) + and any( + name.endswith(f"@{report.latest_step}") for name in report.model_ids_after + ) + ) + + +def assert_length_trainability_passed(report: LengthTrainabilityReport) -> None: + train_samples = [sample for sample in report.samples if sample.split == "train"] + train_rewards_by_step = { + step: [sample.reward for sample in train_samples if sample.step == step] + for step in {sample.step for sample in train_samples} + } + assert train_samples + assert report.latest_step <= report.max_steps + assert report.initial_train_abs_error is not None + assert report.initial_train_abs_error >= INITIAL_ABS_ERROR_MIN + assert report.best_train_abs_error is not None + assert report.best_train_abs_error <= SUCCESS_ABS_ERROR_MAX + assert report.success_step is not None + assert len(train_rewards_by_step) <= report.max_steps + assert all(sample.max_tokens > sample.target_tokens for sample in train_samples) + assert any(sample.generated_tokens < sample.max_tokens for sample in train_samples) + assert any(len(set(rewards)) > 1 for rewards in train_rewards_by_step.values()) + assert any( + 
name.endswith(f"@{report.latest_step}") for name in report.model_ids_after + ) diff --git a/tests/integration/megatron/trainability/yes_no_trainability.py b/tests/integration/megatron/trainability/yes_no_trainability.py index 46675535e..a3f7918ae 100644 --- a/tests/integration/megatron/trainability/yes_no_trainability.py +++ b/tests/integration/megatron/trainability/yes_no_trainability.py @@ -8,7 +8,7 @@ from pathlib import Path import re import time -from typing import Any, AsyncIterator, Iterator, Literal, TypedDict, cast +from typing import Any, AsyncIterator, Iterator, Literal, cast import uuid from pydantic import BaseModel, Field @@ -42,10 +42,6 @@ ] -class _TrainKwargs(TypedDict): - packed_sequence_length: int - - class TrainabilityStepReport(BaseModel): step: int eval_reward: float @@ -286,6 +282,7 @@ def _engine_args_for_yes_no_trainability( "max_num_seqs": _get_env_int("ART_MODEL_SUPPORT_YES_NO_MAX_NUM_SEQS", 4), "enforce_eager": True, "tensor_parallel_size": tensor_parallel_size, + "limit_mm_per_prompt": {"image": 0, "video": 0, "audio": 0}, } if enable_expert_parallel: engine_args["enable_expert_parallel"] = True @@ -379,14 +376,24 @@ def _variant_packed_sequence_length(variant: _TrainabilityVariant) -> int: return _get_env_int("ART_MODEL_SUPPORT_YES_NO_PACKED_SEQUENCE_LENGTH", 1024) -def _variant_train_kwargs(variant: _TrainabilityVariant) -> _TrainKwargs: - return {"packed_sequence_length": _variant_packed_sequence_length(variant)} - - def _variant_init_args(variant: _TrainabilityVariant) -> dev.InitArgs: return {"max_seq_length": _variant_packed_sequence_length(variant)} +def _init_megatron_runtime_config(variant: _TrainabilityVariant) -> None: + if variant.topology is None: + return + art.init_megatron_runtime_config( + topology=art.MegatronTopologyConfig( + tp=variant.topology.tp, + cp=variant.topology.cp, + ep=variant.topology.ep, + etp=variant.topology.etp, + ), + packed_sequence_length=_variant_packed_sequence_length(variant), + ) + + def 
_variant_max_steps(variant: _TrainabilityVariant) -> int: default = 12 if variant.backend_name == "local" else 4 return _get_env_int("ART_MODEL_SUPPORT_YES_NO_MAX_STEPS", default) @@ -689,6 +696,7 @@ async def run_yes_no_trainability_async( allow_unvalidated_arch=allow_unvalidated_arch, ) rollout_weights_mode = internal_config["rollout_weights_mode"] + _init_megatron_runtime_config(variant) model = art.TrainableModel( name=f"{variant.name}-{uuid.uuid4().hex[:8]}", project="model-support-validation", @@ -696,8 +704,6 @@ async def run_yes_no_trainability_async( _internal_config=internal_config, report_metrics=[], ) - train_kwargs = _variant_train_kwargs(variant) - async with _backend_context( variant, backend_root=backend_root, extra_env=extra_env ) as backend: @@ -750,7 +756,6 @@ async def run_yes_no_trainability_async( 1e-4, ), loss_fn="cispo", - packed_sequence_length=train_kwargs["packed_sequence_length"], ) await model.log( train_groups, diff --git a/tests/support/chat_template_conformance_cases.py b/tests/support/chat_template_conformance_cases.py index b39d8f8d0..5912c5784 100644 --- a/tests/support/chat_template_conformance_cases.py +++ b/tests/support/chat_template_conformance_cases.py @@ -7,7 +7,8 @@ from pydantic import BaseModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from art.trajectories import History, Trajectory, TrajectoryGroup +from art.preprocessing.tokenize import _apply_chat_template_token_ids +from art.trajectories import History, Trajectory, TrajectoryGroup, get_messages from art.types import MessagesAndChoices, Tools @@ -76,7 +77,7 @@ def _choice_for_text( "refusal": None, }, "message": { - "content": text, + "content": "" if tool_calls else text, "refusal": None, "role": "assistant", "annotations": None, @@ -88,6 +89,85 @@ def _choice_for_text( ) +def _logprob_content(token_ids: list[int]) -> list[dict[str, Any]]: + return [ + { + "token": f"token_id:{token_id}", + "bytes": 
list(str(token_id).encode("utf-8")), + "logprob": -0.1, + "top_logprobs": [], + } + for token_id in token_ids + ] + + +def _choice_with_token_metadata( + choice: Choice, + *, + prompt_token_ids: list[int], + completion_token_ids: list[int], +) -> Choice: + payload = choice.model_dump(mode="python") + payload["logprobs"]["content"] = _logprob_content(completion_token_ids) + payload["prompt_token_ids"] = prompt_token_ids + payload["token_ids"] = completion_token_ids + return Choice.model_validate(payload) + + +def _rendered_ids( + tokenizer: PreTrainedTokenizerBase, + messages_and_choices: MessagesAndChoices, + tools: Tools | None, +) -> list[int]: + return _apply_chat_template_token_ids( + tokenizer, + cast(list[dict[str, Any]], get_messages(messages_and_choices)), + tools=tools, + tokenize=True, + add_generation_prompt=False, + ) + + +def _attach_token_metadata_to_history( + tokenizer: PreTrainedTokenizerBase, + history: Trajectory | History, +) -> None: + items = history.messages_and_choices + for index, item in enumerate(items): + if not isinstance(item, Choice): + continue + prompt_token_ids = _rendered_ids(tokenizer, items[:index], history.tools) + rendered_ids = _rendered_ids(tokenizer, items[: index + 1], history.tools) + completion_token_ids = rendered_ids[len(prompt_token_ids) :] + items[index] = _choice_with_token_metadata( + item, + prompt_token_ids=prompt_token_ids, + completion_token_ids=completion_token_ids, + ) + + +def _attach_token_metadata( + tokenizer: PreTrainedTokenizerBase, + inputs: "ChatTemplateConformanceInputs", +) -> "ChatTemplateConformanceInputs": + groups = ( + inputs.text_pack_group, + inputs.tool_conversation_group, + inputs.additional_histories_group, + ) + trajectories = [ + inputs.non_final_tool_call_base, + inputs.non_final_tool_call_mutated, + inputs.unsupported_assistant_tool_calls, + *(trajectory for group in groups for trajectory in group.trajectories), + ] + for trajectory in trajectories: + 
_attach_token_metadata_to_history(tokenizer, trajectory) + for history in trajectory.additional_histories: + _attach_token_metadata_to_history(tokenizer, history) + return inputs + + def _messages_and_choices(*items: Any) -> MessagesAndChoices: return cast(MessagesAndChoices, list(items)) @@ -115,7 +195,7 @@ def build_chat_template_conformance_inputs( tools = _tool_schema() - return ChatTemplateConformanceInputs( + inputs = ChatTemplateConformanceInputs( text_pack_group=TrajectoryGroup( [ Trajectory( @@ -278,3 +358,4 @@ def build_chat_template_conformance_inputs( tools=tools, ), ) + return _attach_token_metadata(tokenizer, inputs) diff --git a/tests/unit/test_dedicated_config.py b/tests/unit/test_dedicated_config.py index adb3cbe72..5f09cfbce 100644 --- a/tests/unit/test_dedicated_config.py +++ b/tests/unit/test_dedicated_config.py @@ -9,6 +9,14 @@ from art.dev.validate import is_dedicated_mode, validate_dedicated_config +@pytest.fixture(autouse=True) +def _stub_model_max_seq_length(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "art.dev.get_model_config.max_seq_length_from_model_config", + lambda *_args, **_kwargs: 2048, + ) + + def test_shared_mode_empty_config(): config = InternalModelConfig() assert is_dedicated_mode(config) is False diff --git a/tests/unit/test_model_openai_client_costs.py b/tests/unit/test_model_openai_client_costs.py index 60d25adf3..0c6292c6a 100644 --- a/tests/unit/test_model_openai_client_costs.py +++ b/tests/unit/test_model_openai_client_costs.py @@ -204,8 +204,10 @@ def test_trainable_model_uses_configured_chat_template_kwargs(self) -> None: ) assert model._default_chat_completion_extra_body() == { + "return_token_ids": True, + "return_tokens_as_token_ids": True, "chat_template_kwargs": { "enable_thinking": False, "preserve_thinking": True, - } + }, } diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py index ab0d5765e..f679d45d6 100644 --- 
a/tests/unit/test_pipeline_trainer_local_backend.py +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -9,12 +9,10 @@ import pytest import torch -from transformers.tokenization_utils_base import PreTrainedTokenizerBase from art import TrainableModel, Trajectory, TrajectoryGroup from art.dev.model import InternalModelConfig from art.local import LocalBackend -from art.megatron import MegatronBackend from art.megatron.train import load_adapter_into_model from art.pipeline_trainer import ( CHECKPOINT_CREATED_AT_METRIC, @@ -22,7 +20,6 @@ CheckpointRetentionContext, ) from art.pipeline_trainer.trainer import PipelineTrainer -from art.preprocessing.tokenize import TokenizedResult from art.utils.output_dirs import get_model_dir, get_step_checkpoint_dir @@ -101,33 +98,6 @@ async def test_pipeline_trainer_preserves_backend_train_kwargs(tmp_path: Path) - } -@pytest.mark.asyncio -async def test_pipeline_trainer_forwards_packed_sequence_length_when_set( - tmp_path: Path, -) -> None: - model = TrainableModel( - name="pipeline-packed-sequence-length", - project="pipeline-tests", - base_model="test-model", - base_path=str(tmp_path), - ) - backend = MagicMock() - backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={})) - - trainer = _make_trainer( - model=model, - backend=backend, - packed_sequence_length=4096, - ) - trainer._output_queue = asyncio.Queue() - await trainer._output_queue.put(_make_group([0.0, 1.0])) - await trainer._output_queue.put(None) - - await trainer._training_stage() - - assert backend.train.await_args.kwargs["packed_sequence_length"] == 4096 - - @pytest.mark.asyncio async def test_pipeline_trainer_forwards_default_kl_step_zero_for_generic_backend( tmp_path: Path, @@ -675,112 +645,6 @@ async def test_pipeline_trainer_logs_checkpoint_retention_metadata( assert rows[1][CHECKPOINT_EVAL_COMPLETED_METRIC] == 1.0 -def _make_tokenized_result( - trajectory: Trajectory, - token_ids: list[int], -) -> TokenizedResult: - tokenizer = cast( - 
PreTrainedTokenizerBase, - SimpleNamespace(eos_token_id=0, decode=lambda token_id: str(token_id)), - ) - return TokenizedResult( - advantage=1.0, - chat="", - token_ids=token_ids, - input_pos=list(range(len(token_ids))), - assistant_mask=[0] * (len(token_ids) - 1) + [1], - logprobs=[float("nan")] * (len(token_ids) - 1) + [-0.1], - pixel_values=None, - image_grid_thw=None, - trajectory=trajectory, - choice_offsets=[], - extra_logprobs={}, - _tokenizer=tokenizer, - weight=1.0, - prompt_id=123, - prompt_length=1, - ) - - -def test_local_backend_get_packed_tensors_warns_and_drops_overlong_results( - tmp_path: Path, -) -> None: - backend = LocalBackend(path=str(tmp_path)) - model = TrainableModel( - name="local-backend-packed-sequence-length", - project="pipeline-tests", - base_model="test-model", - base_path=str(tmp_path), - ) - short_trajectory = Trajectory( - reward=1.0, - initial_policy_version=0, - messages_and_choices=[ - {"role": "user", "content": "short"}, - {"role": "assistant", "content": "answer"}, - ], - ) - long_trajectory = Trajectory( - reward=1.0, - initial_policy_version=0, - messages_and_choices=[ - {"role": "user", "content": "long"}, - {"role": "assistant", "content": "answer"}, - ], - ) - short_result = _make_tokenized_result(short_trajectory, [1, 2, 3, 4]) - long_result = _make_tokenized_result(long_trajectory, list(range(10))) - - with ( - patch( - "art.local.backend.AutoTokenizer.from_pretrained", - return_value=short_result._tokenizer, - ), - patch("transformers.AutoImageProcessor.from_pretrained", return_value=None), - patch( - "art.local.backend.tokenize_trajectory_groups", - return_value=iter([short_result, long_result]), - ), - pytest.warns(UserWarning, match="Dropping 1 tokenized results"), - ): - packed_tensors = backend._get_packed_tensors( - model, - [_make_group([0.0, 1.0])], - advantage_balance=0.0, - allow_training_without_logprobs=False, - scale_rewards=True, - plot_tensors=False, - packed_sequence_length=4, - 
logprob_calculation_chunk_size=2, - ) - - assert packed_tensors is not None - assert packed_tensors["tokens"].shape == (1, 4) - - -@pytest.mark.asyncio -async def test_megatron_backend_train_requires_packed_sequence_length( - tmp_path: Path, -) -> None: - model = TrainableModel( - name="megatron-backend-packed-sequence-length", - project="pipeline-tests", - base_model="test-model", - base_path=str(tmp_path), - ) - backend = MegatronBackend(path=str(tmp_path)) - - with patch.object(model, "_get_wandb_run", return_value=None): - with pytest.raises( - ValueError, match="MegatronBackend\\.train requires packed_sequence_length" - ): - await backend.train( - model, - [_make_group([1.0])], - save_checkpoint=False, - ) - - def test_load_adapter_into_model_reloads_optimizer_when_provided() -> None: class FakeModule(torch.nn.Module): def __init__(self) -> None: diff --git a/tests/unit/test_preprocessing_tokenize.py b/tests/unit/test_preprocessing_tokenize.py index 587207b30..e28f3ac75 100644 --- a/tests/unit/test_preprocessing_tokenize.py +++ b/tests/unit/test_preprocessing_tokenize.py @@ -1,29 +1,18 @@ -import importlib -import sys from typing import Any, cast -from openai.types.chat.chat_completion import Choice import pytest from transformers.tokenization_utils_base import BatchEncoding -from art.preprocessing.tokenize import tokenize_sft_batch, tokenize_trajectory -from art.trajectories import History, Trajectory +from art.preprocessing.tokenize import tokenize_sft_batch +from art.trajectories import Trajectory from art.types import MessagesAndChoices -if "tests" not in sys.path: - sys.path.insert(0, "tests") - -build_chat_template_conformance_inputs = importlib.import_module( - "support.chat_template_conformance_cases" -).build_chat_template_conformance_inputs - pytest.importorskip("torch") pytest.importorskip("transformers") class _FakeTokenizer: chat_template = "" - vocab_size = 256 eos_token = "\x00" eos_token_id = 0 @@ -86,126 +75,6 @@ def 
convert_tokens_to_ids(self, tokens): return self.eos_token_id -class _Qwen3_5FakeTokenizer(_FakeTokenizer): - chat_template = ( - "{% for args_name, args_value in tool_call.arguments|items %}{% endfor %}" - ) - - def apply_chat_template( - self, - messages, - tools=None, - tokenize=True, - return_dict=None, - **kwargs, - ): - for message in messages: - tool_calls = message.get("tool_calls") - if tool_calls is None: - continue - assert isinstance(tool_calls, list) - for tool_call in tool_calls: - assert isinstance(tool_call, dict) - function = tool_call["function"] - assert isinstance(function, dict) - assert isinstance(function["arguments"], dict) - return super().apply_chat_template( - messages, - tools=tools, - tokenize=tokenize, - return_dict=return_dict, - **kwargs, - ) - - -class _ContinueFinalMessageRejectingTokenizer(_FakeTokenizer): - def apply_chat_template( - self, - messages, - tools=None, - tokenize=True, - return_dict=None, - **kwargs, - ): - if kwargs.get("continue_final_message") is True and messages[-1].get( - "content", "" - ).startswith(""): - raise ValueError( - "continue_final_message is set but the final message does not appear " - "in the chat after applying the chat template!" 
- ) - return super().apply_chat_template( - messages, - tools=tools, - tokenize=tokenize, - return_dict=return_dict, - **kwargs, - ) - - -def test_tokenize_trajectory_accepts_batchencoding_chat_template_output() -> None: - tokenizer = _FakeTokenizer() - messages = cast( - MessagesAndChoices, - [ - {"role": "user", "content": "Hi"}, - {"role": "assistant", "content": "OK"}, - ], - ) - history = History(messages_and_choices=messages) - trajectory = Trajectory(messages_and_choices=messages, reward=1.0) - - result = tokenize_trajectory( - tokenizer=tokenizer, # type: ignore[arg-type] - image_processor=None, - history=history, - advantage=1.0, - allow_training_without_logprobs=True, - trajectory=trajectory, - ) - - assert result is not None - assistant_ids = [ - token_id - for token_id, mask in zip(result.token_ids, result.assistant_mask) - if mask - ] - assert assistant_ids == tokenizer.encode("OK", add_special_tokens=False) - - -def test_tokenize_trajectory_passes_chat_template_kwargs() -> None: - tokenizer = _FakeTokenizer() - messages = cast( - MessagesAndChoices, - [ - {"role": "user", "content": "Hi"}, - {"role": "assistant", "content": "OK"}, - ], - ) - history = History(messages_and_choices=messages) - trajectory = Trajectory(messages_and_choices=messages, reward=1.0) - - result = tokenize_trajectory( - tokenizer=tokenizer, # type: ignore[arg-type] - image_processor=None, - history=history, - advantage=1.0, - allow_training_without_logprobs=True, - trajectory=trajectory, - chat_template_kwargs={ - "enable_thinking": False, - "preserve_thinking": True, - }, - ) - - assert result is not None - assert tokenizer.apply_chat_template_kwargs - assert all( - call.get("enable_thinking") is False and call.get("preserve_thinking") is True - for call in tokenizer.apply_chat_template_kwargs - ) - - def test_tokenize_sft_batch_masks_response_tokens_without_unsloth_import() -> None: tokenizer = _FakeTokenizer() messages = cast( @@ -230,253 +99,27 @@ def 
test_tokenize_sft_batch_masks_response_tokens_without_unsloth_import() -> No assert batch.num_trainable_tokens == 2 -def test_tokenize_trajectory_does_not_continue_real_completion_with_thinking() -> None: - tokenizer = _ContinueFinalMessageRejectingTokenizer() - choice = Choice.model_validate( - { - "finish_reason": "stop", - "index": 0, - "logprobs": { - "content": [ - { - "token": "token_id:79", - "bytes": [79], - "logprob": -0.1, - "top_logprobs": [], - }, - { - "token": "token_id:75", - "bytes": [75], - "logprob": -0.2, - "top_logprobs": [], - }, - ], - "refusal": None, - }, - "message": { - "content": "\nreasoning\n\n\nOK", - "refusal": None, - "role": "assistant", - "annotations": None, - "audio": None, - "function_call": None, - "tool_calls": None, - }, - } - ) +def test_tokenize_sft_batch_passes_chat_template_kwargs() -> None: + tokenizer = _FakeTokenizer() messages = cast( MessagesAndChoices, [ {"role": "user", "content": "Hi"}, - choice, + {"role": "assistant", "content": "OK"}, ], ) - history = History(messages_and_choices=messages) - trajectory = Trajectory(messages_and_choices=messages, reward=1.0) - result = tokenize_trajectory( + tokenize_sft_batch( + trajectory_batch=[Trajectory(messages_and_choices=messages, reward=1.0)], + learning_rate=1e-5, tokenizer=tokenizer, # type: ignore[arg-type] - image_processor=None, - history=history, - advantage=1.0, - allow_training_without_logprobs=False, - trajectory=trajectory, + instruction_part="", + response_part="", chat_template_kwargs={ "enable_thinking": False, "preserve_thinking": True, }, ) - assert result is not None - assistant_ids = [ - token_id - for token_id, mask in zip(result.token_ids, result.assistant_mask) - if mask - ] - assert assistant_ids == [79, 75] - continue_values = [ - call.get("continue_final_message") - for call in tokenizer.apply_chat_template_kwargs - ] - assert continue_values[:2] == [False, False] - assert continue_values[-1] is True - - -def 
test_tokenize_trajectory_normalizes_mapping_tool_arguments_for_chat_template() -> ( - None -): - tokenizer = _Qwen3_5FakeTokenizer() - choice = Choice.model_validate( - { - "finish_reason": "stop", - "index": 0, - "logprobs": { - "content": [ - { - "token": "token_id:65", - "bytes": [65], - "logprob": -0.1, - "top_logprobs": [], - } - ], - "refusal": None, - }, - "message": { - "content": "", - "refusal": None, - "role": "assistant", - "annotations": None, - "audio": None, - "function_call": None, - "tool_calls": [ - { - "id": "call_1", - "function": { - "arguments": '{"city": "San Francisco", "days": 3}', - "name": "lookup_weather", - }, - "type": "function", - } - ], - }, - } - ) - messages = cast( - MessagesAndChoices, - [ - {"role": "user", "content": "Weather?"}, - choice, - ], - ) - history = History(messages_and_choices=messages) - trajectory = Trajectory(messages_and_choices=messages, reward=1.0) - - result = tokenize_trajectory( - tokenizer=tokenizer, # type: ignore[arg-type] - image_processor=None, - history=history, - advantage=1.0, - allow_training_without_logprobs=False, - trajectory=trajectory, - ) - - assert result is not None - - -def test_tokenize_trajectory_uses_exact_tokens_for_malformed_final_tool_call() -> None: - tokenizer = _Qwen3_5FakeTokenizer() - choice = Choice.model_validate( - { - "finish_reason": "tool_calls", - "index": 0, - "logprobs": { - "content": [ - { - "token": "token_id:65", - "bytes": [65], - "logprob": -0.1, - "top_logprobs": [], - } - ], - "refusal": None, - }, - "message": { - "content": "prefix", - "refusal": None, - "role": "assistant", - "annotations": None, - "audio": None, - "function_call": None, - "tool_calls": [ - { - "id": "call_1", - "function": { - "arguments": '{"offer_id": None}', - "name": "create_booking", - }, - "type": "function", - } - ], - }, - } - ) - messages = cast( - MessagesAndChoices, - [ - {"role": "user", "content": "Book it."}, - choice, - ], - ) - result = tokenize_trajectory( - 
tokenizer=tokenizer, # type: ignore[arg-type] - image_processor=None, - history=History(messages_and_choices=messages), - advantage=1.0, - allow_training_without_logprobs=False, - trajectory=Trajectory(messages_and_choices=messages, reward=1.0), - ) - - assert result is not None - assistant_ids = [ - token_id - for token_id, mask in zip(result.token_ids, result.assistant_mask) - if mask - ] - assert assistant_ids == [65] - - -def test_tokenize_trajectory_non_final_tool_call_mutation_changes_prefill_tokens() -> ( - None -): - tokenizer = _Qwen3_5FakeTokenizer() - inputs = build_chat_template_conformance_inputs(tokenizer) # type: ignore[arg-type] - - base = tokenize_trajectory( - tokenizer=tokenizer, # type: ignore[arg-type] - image_processor=None, - history=History( - messages_and_choices=inputs.non_final_tool_call_base.messages_and_choices, - tools=inputs.non_final_tool_call_base.tools, - ), - advantage=1.0, - allow_training_without_logprobs=False, - trajectory=inputs.non_final_tool_call_base, - ) - mutated = tokenize_trajectory( - tokenizer=tokenizer, # type: ignore[arg-type] - image_processor=None, - history=History( - messages_and_choices=inputs.non_final_tool_call_mutated.messages_and_choices, - tools=inputs.non_final_tool_call_mutated.tools, - ), - advantage=1.0, - allow_training_without_logprobs=False, - trajectory=inputs.non_final_tool_call_mutated, - ) - - assert base is not None - assert mutated is not None - assert len(base.choice_offsets) >= 2 - assert len(mutated.choice_offsets) >= 2 - assert ( - base.token_ids[: base.choice_offsets[-1]] - != mutated.token_ids[: mutated.choice_offsets[-1]] - ) - - -def test_tokenize_trajectory_rejects_assistant_tool_calls_without_logprobs() -> None: - tokenizer = _Qwen3_5FakeTokenizer() - inputs = build_chat_template_conformance_inputs(tokenizer) # type: ignore[arg-type] - - with pytest.raises(ValueError, match="Assistant message has tool_calls"): - tokenize_trajectory( - tokenizer=tokenizer, # type: ignore[arg-type] - 
image_processor=None, - history=History( - messages_and_choices=inputs.unsupported_assistant_tool_calls.messages_and_choices, - tools=inputs.unsupported_assistant_tool_calls.tools, - ), - advantage=1.0, - allow_training_without_logprobs=True, - trajectory=inputs.unsupported_assistant_tool_calls, - ) + assert tokenizer.apply_chat_template_kwargs[-1]["enable_thinking"] is False + assert tokenizer.apply_chat_template_kwargs[-1]["preserve_thinking"] is True diff --git a/tests/unit/test_tinker_renderers.py b/tests/unit/test_tinker_renderers.py index 35db45bef..c03f87129 100644 --- a/tests/unit/test_tinker_renderers.py +++ b/tests/unit/test_tinker_renderers.py @@ -111,7 +111,7 @@ def test_qwen3_5_parse_response_handles_xml_tool_calls() -> None: message, success = renderer.parse_response(response) - assert success is True + assert success == renderers.ParseTermination.STOP_SEQUENCE assert message["content"] == [ {"type": "thinking", "thinking": "reasoning"}, {"type": "text", "text": "Answer first.\n\n"}, diff --git a/tests/unit/test_tokenize_trajectory_groups.ipynb b/tests/unit/test_tokenize_trajectory_groups.ipynb index 7b10993e6..7f19af951 100644 --- a/tests/unit/test_tokenize_trajectory_groups.ipynb +++ b/tests/unit/test_tokenize_trajectory_groups.ipynb @@ -71,6 +71,44 @@ ], "source": [ "# NBVAL_IGNORE_OUTPUT\n", + "prompt_token_ids = [\n", + " 151644,\n", + " 8948,\n", + " 198,\n", + " 2610,\n", + " 525,\n", + " 1207,\n", + " 16948,\n", + " 11,\n", + " 3465,\n", + " 553,\n", + " 54364,\n", + " 14817,\n", + " 13,\n", + " 1446,\n", + " 525,\n", + " 264,\n", + " 10950,\n", + " 17847,\n", + " 13,\n", + " 151645,\n", + " 198,\n", + " 151644,\n", + " 872,\n", + " 198,\n", + " 3838,\n", + " 374,\n", + " 279,\n", + " 6722,\n", + " 315,\n", + " 9625,\n", + " 30,\n", + " 151645,\n", + " 198,\n", + " 151644,\n", + " 77091,\n", + " 198,\n", + "]\n", "tokenized_results = list(\n", " tokenize_trajectory_groups(\n", " tokenizer,\n", @@ -83,7 +121,19 @@ " \"role\": \"user\",\n", " 
\"content\": \"What is the capital of France?\",\n", " },\n", - " {\"role\": \"assistant\", \"content\": \"London\"},\n", + " Choice.model_validate(\n", + " {\n", + " \"finish_reason\": \"stop\",\n", + " \"index\": 0,\n", + " \"logprobs\": None,\n", + " \"message\": ChatCompletionMessage(\n", + " content=\"London\",\n", + " role=\"assistant\",\n", + " ),\n", + " \"prompt_token_ids\": prompt_token_ids,\n", + " \"token_ids\": [39572],\n", + " }\n", + " ),\n", " ],\n", " reward=0.0,\n", " ),\n", @@ -93,23 +143,27 @@ " \"role\": \"user\",\n", " \"content\": \"What is the capital of France?\",\n", " },\n", - " Choice(\n", - " finish_reason=\"stop\",\n", - " index=0,\n", - " logprobs=ChoiceLogprobs(\n", - " content=[\n", - " ChatCompletionTokenLogprob(\n", - " token=\"token:59604\",\n", - " bytes=[80, 97, 114, 105, 115],\n", - " logprob=-0.01,\n", - " top_logprobs=[],\n", - " )\n", - " ]\n", - " ),\n", - " message=ChatCompletionMessage(\n", - " content=\"Paris\",\n", - " role=\"assistant\",\n", - " ),\n", + " Choice.model_validate(\n", + " {\n", + " \"finish_reason\": \"stop\",\n", + " \"index\": 0,\n", + " \"logprobs\": ChoiceLogprobs(\n", + " content=[\n", + " ChatCompletionTokenLogprob(\n", + " token=\"token:59604\",\n", + " bytes=[80, 97, 114, 105, 115],\n", + " logprob=-0.01,\n", + " top_logprobs=[],\n", + " )\n", + " ]\n", + " ),\n", + " \"message\": ChatCompletionMessage(\n", + " content=\"Paris\",\n", + " role=\"assistant\",\n", + " ),\n", + " \"prompt_token_ids\": prompt_token_ids,\n", + " \"token_ids\": [59604],\n", + " }\n", " ),\n", " ],\n", " reward=1.0,\n", diff --git a/uv.lock b/uv.lock index 1509d0a09..36ad513d8 100644 --- a/uv.lock +++ b/uv.lock @@ -2,29 +2,51 @@ version = 1 revision = 3 requires-python = ">=3.12" resolution-markers = [ - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'linux'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'linux'", - "python_full_version < 
'3.13' and platform_machine != 's390x' and sys_platform == 'linux'", - "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'win32'", - "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", -] + "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'linux' and 
extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version 
== '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + 
"python_full_version >= '3.14' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version >= '3.14' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version >= '3.14' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "(python_full_version >= '3.14' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", + "(python_full_version == '3.13.*' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version == '3.13.*' and sys_platform == 
'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", + "(python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", + "python_full_version >= '3.14' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", +] +conflicts = [[ + { package = "openpipe-art", extra = "backend" }, + { package = "openpipe-art", extra = "megatron" }, +], [ + { package = "openpipe-art", extra = "megatron" }, + { package = "openpipe-art", extra = "tinker" }, +]] [manifest] overrides = [ - { name = "flashinfer-python", specifier = "==0.6.1" }, + { name = "flashinfer-python", specifier = "==0.6.8.post1" }, { name = "megatron-core", specifier = "==0.17.0" }, { name = "numpy", specifier = "<2" }, { name = "nvidia-resiliency-ext", specifier = "<0.5" }, @@ -32,7 +54,6 @@ overrides = [ { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = "==2.11.0" }, { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = "==2.11.0", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformer-engine", specifier = "==2.11.0" }, - { name = "transformers", specifier = "==5.2.0" }, ] excludes = [ "causal-conv1d", @@ -50,6 +71,11 @@ requires-dist = ["packaging"] name = "deep-ep" version = 
"1.2.1+9af0e0d" +[[manifest.dependency-metadata]] +name = "megatron-bridge" +version = "0.5.0+e1a207ac" +requires-dist = ["accelerate", "comet-ml", "datasets", "diffusers", "einops", "flash-linear-attention", "flashinfer-cubin", "flashinfer-python", "hydra-core", "imageio", "imageio-ffmpeg", "megatron-core", "mistral-common", "mlflow", "nvidia-resiliency-ext", "omegaconf", "open-clip-torch", "peft", "pyyaml", "qwen-vl-utils", "regex", "rich", "six", "tensorboard", "timm", "torch", "tqdm", "transformers", "typing-extensions", "wandb"] + [[manifest.dependency-metadata]] name = "transformer-engine-torch" version = "2.11.0" @@ -87,8 +113,8 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' 
and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/97/33/47bbd507e3a851d33d19ce7b2141c5ea3689bfae91ba168044d7db24b0e9/accelerate-1.7.0.tar.gz", hash = "sha256:e8a2a5503d6237b9eee73cc8d36cf543f9c2d8dd2c6713450b322f5e6d53a610", size = 376026, upload-time = "2025-05-15T10:00:52.117Z" } wheels = [ @@ -213,9 +239,9 @@ wheels = [ [package.optional-dependencies] speedups = [ { name = "aiodns" }, - { name = "backports-zstd", marker = "python_full_version < '3.14' and platform_python_implementation == 'CPython'" }, - { name = "brotli", marker = "platform_python_implementation == 'CPython'" }, - { name = "brotlicffi", marker = "platform_python_implementation != 'CPython'" }, + { name = "backports-zstd", marker = "(python_full_version < '3.14' and platform_python_implementation == 'CPython') or (python_full_version >= '3.14' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (python_full_version >= '3.14' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (platform_python_implementation != 'CPython' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (platform_python_implementation != 'CPython' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "brotli", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 
'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "brotlicffi", marker = "platform_python_implementation != 'CPython' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] [[package]] @@ -236,7 +262,7 @@ version = "1.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "frozenlist" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" } wheels = [ @@ -296,7 +322,7 @@ version = "4.13.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "idna" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/19/14/2c5dd9f512b66549ae92767a9c7b330ae88e1932ca57876909410251fe13/anyio-4.13.0.tar.gz", hash = "sha256:334b70e641fd2221c1505b3890c69882fe4a2df910cba14d97019b90b24439dc", size = 231622, upload-time = "2026-03-24T12:59:09.671Z" } wheels = [ @@ -562,8 +588,8 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = 
"packaging" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/d8/7d/f1fe0992334b18cd8494f89aeec1dcc674635584fcd9f115784fea3a1d05/bitsandbytes-0.49.2-py3-none-macosx_14_0_arm64.whl", hash = "sha256:87be5975edeac5396d699ecbc39dfc47cf2c026daaf2d5852a94368611a6823f", size = 131940, upload-time = "2026-02-16T21:26:04.572Z" }, @@ -759,7 +785,7 @@ name = "cffi" version = "2.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pycparser", marker = "implementation_name != 'PyPy'" }, + { name = "pycparser", marker = "implementation_name != 'PyPy' or (extra == 
'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" } wheels = [ @@ -941,7 +967,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593, upload-time = "2024-12-21T18:38:44.339Z" } wheels = [ @@ -972,7 +998,7 @@ version = "3.58.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dulwich" }, - { name = "everett", extra = ["ini"] }, + { name = "everett", extra = ["ini"], marker = "extra == 'extra-12-openpipe-art-megatron'" }, { name = "jsonschema" }, { name = "psutil" }, { name = "python-box" }, @@ -1165,7 +1191,7 @@ name = "cryptography" version = "43.0.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, + { name = "cffi", marker = "platform_python_implementation != 'PyPy' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = 
{ url = "https://files.pythonhosted.org/packages/0d/05/07b55d1fa21ac18c3a8c79f764e2514e6f6a9698f1be44994f5adf0d29db/cryptography-43.0.3.tar.gz", hash = "sha256:315b9001266a492a6ff443b61238f956b214dbec9910a081ba5b6646a055a805", size = 686989, upload-time = "2024-10-18T15:58:32.918Z" } wheels = [ @@ -1194,7 +1220,7 @@ name = "cuda-bindings" version = "12.9.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cuda-pathfinder" }, + { name = "cuda-pathfinder", marker = "sys_platform == 'linux' or sys_platform == 'win32' or extra == 'extra-12-openpipe-art-megatron'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/32/45/557d4ed1fa54f0c7db8aee083229f624990d69f7d00f55477eed5c7e169a/cuda_bindings-12.9.7-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0666d3c082ef8f4b2d670950589373550e9f3bf564d635dd883f24a0b40402ff", size = 7071026, upload-time = "2026-05-27T18:44:13.356Z" }, @@ -1233,6 +1259,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/9d/05e753afbaac3f92691059b3ba875589c98a425d69e5808cec32b31b580c/cuda_python-12.9.7-py3-none-any.whl", hash = "sha256:23a1fc406d491eef7a7e985095725cb7b20a04a7bd9b7a66400e5c86e082e0aa", size = 7597, upload-time = "2026-05-27T19:50:32.605Z" }, ] +[[package]] +name = "cuda-tile" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/93/64ef40d3982dcda7a97ebfa3e3bb9045b573d4eb3877fa5d1fa3cd2541d3/cuda_tile-1.4.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:9e358a85a153820aa0a51d0e09346d884a3c14b88c0313d20d0fb9f53952abae", size = 280953, upload-time = "2026-05-27T17:46:53.03Z" }, + { url = "https://files.pythonhosted.org/packages/d7/9a/7fbdbdb30c375f80818941165adfc4f1dc6cebaf937c6a9081a02d5871f0/cuda_tile-1.4.0-cp312-cp312-manylinux2014_x86_64.whl", hash = 
"sha256:1d9d99b6fa57366af3f8707ac4fd91411275af2ee736996a60620240fcf92070", size = 282503, upload-time = "2026-05-27T17:45:05.543Z" }, + { url = "https://files.pythonhosted.org/packages/6f/bb/4152dc08a8de5bcdc4b9d80b6917216289526f6e786b09ee80d4df27bcfb/cuda_tile-1.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:616f13cbc7af6caa7b92430b85ba0a429d1f96ca9e7e04a29d89114cfe859663", size = 269813, upload-time = "2026-05-27T17:46:20.583Z" }, + { url = "https://files.pythonhosted.org/packages/5e/ad/42f0655e6aee5c59015634b46d7f13bc22e74af28d10fb2008a062b37349/cuda_tile-1.4.0-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:fc74185efd81f6153af0a19549d111dec6861ee9b9bc27927a2cef6e19173eb5", size = 280958, upload-time = "2026-05-27T17:46:53.061Z" }, + { url = "https://files.pythonhosted.org/packages/11/0b/4770f9e36b8108ce8c9078f71eb21c65e594d79c0770dd38daa045cfbd6c/cuda_tile-1.4.0-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:45be74f6568c440446f510bc7799b953858e64c6abf26e96f2c9598a79084860", size = 282508, upload-time = "2026-05-27T17:45:18.515Z" }, + { url = "https://files.pythonhosted.org/packages/a1/67/41f1acdf21bf6214a3a1c3b46d39b8eb0f9eba7aecc6b57005db35d56f9a/cuda_tile-1.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:edd1df4d7955032c7be2a26c6d7e47261415ba7c87587705e0f4f1fd0d61650a", size = 269783, upload-time = "2026-05-27T17:47:16.631Z" }, + { url = "https://files.pythonhosted.org/packages/0d/c6/46a329f4c56ce54471784366394e235804423df2531307e14112e4636c76/cuda_tile-1.4.0-cp314-cp314-manylinux2014_aarch64.whl", hash = "sha256:738593650784ebb3c601486914b563e7569144fe596048766ea9e12280ac3bb9", size = 281208, upload-time = "2026-05-27T17:46:48.325Z" }, + { url = "https://files.pythonhosted.org/packages/8f/fb/bf3849ad68b1858ba50e6992863d266892d7d7db02d11c485c26cd090a1b/cuda_tile-1.4.0-cp314-cp314-manylinux2014_x86_64.whl", hash = "sha256:4b1a591c26836a550c2bf87c22d31c4716e5f83d24d255f843d9429625cca973", size = 282630, upload-time = "2026-05-27T17:45:10.789Z" }, + 
{ url = "https://files.pythonhosted.org/packages/61/bb/211c0d5121230ee76cfc1a9ee107ec28aaae9e6ffb43a04aa172d0d4f4dc/cuda_tile-1.4.0-cp314-cp314-win_amd64.whl", hash = "sha256:c19e10fe70ba92709b6ca446d1c52a8a346b56f4f8ad7c8941736f60e32f3c87", size = 270644, upload-time = "2026-05-27T17:45:01.914Z" }, + { url = "https://files.pythonhosted.org/packages/ab/df/f7f1dfa4d1ee7cc5b69e11d756be6ffec1561a5c7e3836fd0f71ca49adcf/cuda_tile-1.4.0-cp314-cp314t-manylinux2014_aarch64.whl", hash = "sha256:b3cbeffbe0fedac4936edcf00b6ba13ab5ddb74d3b7ce4a287dfc04491b5f6af", size = 283249, upload-time = "2026-05-27T17:46:12.032Z" }, + { url = "https://files.pythonhosted.org/packages/18/c0/fee527a085fca414fc993769912eb8ba2e15ce388f3168b868706e6d4c61/cuda_tile-1.4.0-cp314-cp314t-manylinux2014_x86_64.whl", hash = "sha256:675b2afff62af5d4e72c34bc72d0be27b0933a44933b8a449f590fbded8c1107", size = 284336, upload-time = "2026-05-27T17:44:59.489Z" }, + { url = "https://files.pythonhosted.org/packages/0d/ab/0883194457932150a5ad334d609ac17bd704345974d21c8bae6ea251e7ed/cuda_tile-1.4.0-cp314-cp314t-win_amd64.whl", hash = "sha256:3f58eac5577ea3ed7c17bfcab015a506fd2cf61f8848407c5b403f1bf46c55ca", size = 275861, upload-time = "2026-05-27T17:46:36.285Z" }, +] + [[package]] name = "cuda-toolkit" version = "12.8.1" @@ -1243,37 +1291,37 @@ wheels = [ [package.optional-dependencies] cublas = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cudart = [ - { name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 
'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cufft = [ - { name = "nvidia-cufft-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cufile = [ - { name = "nvidia-cufile-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cupti = [ - { name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] curand = [ - { name = "nvidia-curand-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cusolver = [ - { name = "nvidia-cusolver-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cusparse = [ - { name = "nvidia-cusparse-cu12", marker = 
"sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] nvjitlink = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] nvrtc = [ - { name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] nvtx = [ - { name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] [[package]] @@ -1297,8 +1345,8 @@ name = "cut-cross-entropy" version = "25.1.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = 
"(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "triton", marker = "sys_platform == 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/7e/97/45ff09cfcda7b200389204daa0125168e6544fba257adbbcdf728501d4f9/cut_cross_entropy-25.1.1.tar.gz", hash = "sha256:5fe5924509248b1aea5c890f8887c6a7759f7c8b1ebc0490e42c247c4f7c1e34", size = 22972, upload-time = "2025-01-07T12:21:53.896Z" } @@ -1844,8 +1892,8 @@ version = "0.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "einops" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 
'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/03/14/2aabd37839b9f3c6a67fbc5678f906d04d0c242c603ac234eefe02df99a6/fla_core-0.5.0.tar.gz", hash = "sha256:476dd94711702af81cc4827010d9209f6053d8cdceac8e43d3c8497071f07a81", size = 418171, upload-time = "2026-04-21T20:25:40.948Z" } wheels = [ @@ -1861,8 +1909,8 @@ dependencies = [ { name = "einops" }, { name = "nvidia-cutlass-dsl" }, { name = "quack-kernels" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 
'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torch-c-dlpack-ext" }, { name = "typing-extensions" }, ] @@ -1890,13 +1938,47 @@ version = "0.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fla-core" }, - { name = "transformers" }, + { name = "transformers", version = "5.6.2", source = { registry = "https://pypi.org/simple" } }, ] sdist = { url = "https://files.pythonhosted.org/packages/79/5c/1db76cc829c951117a3112f306d50333bd71399d2e35807fe7c99ffc2007/flash_linear_attention-0.5.0.tar.gz", hash = "sha256:22b789a47f07738b4382ecdf775d7bb40e0d803c467c34f8e2ecd6a1dc780938", size = 160419, upload-time = "2026-04-21T20:25:42.344Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/cc/16/7736db08806981562c728f32ea1dcb4565948fa9faffdbf4ffbf72522fbf/flash_linear_attention-0.5.0-py3-none-any.whl", hash = "sha256:92e64e989ed34355c1f838232597b2e39783ee0494ada3199b58e156aa1d8eb8", size = 319037, upload-time = "2026-04-21T20:25:39.473Z" }, ] +[[package]] +name = "flashinfer-cubin" +version = "0.6.8.post1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/11/b7/5e3b1a8c67031b421a8bd29c2bc29b900a550bb3392e8bda18bb15b5e476/flashinfer_cubin-0.6.8.post1-py3-none-any.whl", hash = "sha256:43636d4cd39e694a83d76a89f87fefcdf4cecb4c4f7dd22dac25ec368c1e901f", size = 295154113, upload-time = "2026-04-18T18:28:21.738Z" }, +] + +[[package]] +name = "flashinfer-python" +version = "0.6.8.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apache-tvm-ffi" }, + { name = "click" }, + { name = "cuda-tile" }, + { name = "einops" }, + { name = "ninja" }, + { name = "numpy" }, + { name = "nvidia-cudnn-frontend" }, + { name = "nvidia-cutlass-dsl" }, + { name = "nvidia-ml-py" }, + { name = "packaging" }, + { name = "requests" }, + { name = "tabulate" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "tqdm" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/53/1e/2760fef9e74abc4480961048e5790b4c9e955872fb4d7d97900cfddced5a/flashinfer_python-0.6.8.post1.tar.gz", hash = "sha256:b18e4121baf9b93fa9a9f368ba9b981a0342895f50ab9dddc224aeb964ed346f", size = 6675885, upload-time = "2026-04-18T18:28:13.299Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/6d/1e8a8533913e33a50a486332ce0673f4fdb860f6eb9ed450327c5c1762cb/flashinfer_python-0.6.8.post1-py3-none-any.whl", hash = "sha256:818f9b8cc2fe66c42a1f6264be4841ac8821ada703685a02cfccb2b5124a710b", size = 9385316, upload-time = "2026-04-18T18:28:10.285Z" }, +] + [[package]] name = "flask" version = "3.1.3" @@ -2420,7 +2502,7 @@ name = "gunicorn" version = "25.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "packaging", marker = "sys_platform != 'win32'" }, + { name = "packaging" }, ] sdist = { url = "https://files.pythonhosted.org/packages/c4/f4/e78fa054248fab913e2eab0332c6c2cb07421fca1ce56d8fe43b6aef57a4/gunicorn-25.3.0.tar.gz", hash = "sha256:f74e1b2f9f76f6cd1ca01198968bd2dd65830edc24b6e8e4d78de8320e2fe889", size = 634883, upload-time = "2026-03-27T00:00:26.092Z" } wheels = [ @@ -2454,7 +2536,7 @@ name = "hatch" version = "1.16.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "backports-zstd", marker = "python_full_version < '3.14'" }, + { name = "backports-zstd", marker = "python_full_version < '3.14' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "click" }, { name = "hatchling" }, { name = "httpx" }, @@ -2663,7 +2745,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, - { name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 
'x86_64'" }, + { name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "httpx" }, { name = "packaging" }, { name = "pyyaml" }, @@ -2870,7 +2952,7 @@ name = "ipykernel" version = "7.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "sys_platform == 'darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -2894,12 +2976,12 @@ name = "ipython" version = "9.13.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "decorator" }, { name = "ipython-pygments-lexers" }, { name = "jedi" }, { name = "matplotlib-inline" }, - { name = "pexpect", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "pexpect", marker = "(sys_platform != 'emscripten' and sys_platform != 'win32') or (sys_platform == 'emscripten' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'emscripten' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and 
extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "prompt-toolkit" }, { name = "psutil" }, { name = "pygments" }, @@ -3216,9 +3298,9 @@ dependencies = [ { name = "jaraco-classes" }, { name = "jaraco-context" }, { name = "jaraco-functools" }, - { name = "jeepney", marker = "sys_platform == 'linux'" }, - { name = "pywin32-ctypes", marker = "sys_platform == 'win32'" }, - { name = "secretstorage", marker = "sys_platform == 'linux'" }, + { name = "jeepney", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "pywin32-ctypes", marker = "sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "secretstorage", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/43/4b/674af6ef2f97d56f0ab5153bf0bfa28ccb6c3ed4d1babf4305449668807b/keyring-25.7.0.tar.gz", hash = "sha256:fe01bd85eb3f8fb3dd0405defdeac9a5b4f6f0439edbb3149577f244a2e8245b", size = 63516, upload-time = "2025-11-16T16:26:09.482Z" } wheels = [ @@ -3440,7 +3522,7 @@ version = "0.8.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, - { name = "orjson", marker = "platform_python_implementation != 'PyPy'" }, + { name = "orjson", marker = "platform_python_implementation != 'PyPy' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 
'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "packaging" }, { name = "pydantic" }, { name = "requests" }, @@ -3731,8 +3813,8 @@ wheels = [ [[package]] name = "megatron-bridge" -version = "0.4.0rc0" -source = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git?rev=e049cc00c24d03e2ae45d2608c7a44e2d2364e3d#e049cc00c24d03e2ae45d2608c7a44e2d2364e3d" } +version = "0.5.0+e1a207ac" +source = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git?rev=e1a207ac757e5d0ed94d8ffbe1cbd28e81d8c084#e1a207ac757e5d0ed94d8ffbe1cbd28e81d8c084" } dependencies = [ { name = "accelerate" }, { name = "comet-ml" }, @@ -3740,10 +3822,13 @@ dependencies = [ { name = "diffusers" }, { name = "einops" }, { name = "flash-linear-attention" }, + { name = "flashinfer-cubin" }, + { name = "flashinfer-python" }, { name = "hydra-core" }, { name = "imageio" }, { name = "imageio-ffmpeg" }, { name = "megatron-core" }, + { name = "mistral-common" }, { name = "mlflow" }, { name = "nvidia-resiliency-ext" }, { name = "omegaconf" }, @@ -3756,11 +3841,10 @@ dependencies = [ { name = "six" }, { name = "tensorboard" }, { name = "timm" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 
'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "tqdm" }, - { name = "transformer-engine" }, - { name = "transformers" }, + { name = "transformers", version = "5.6.2", source = { registry = "https://pypi.org/simple" } }, { name = "typing-extensions" }, { name = "wandb" }, ] @@ -3772,8 +3856,8 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "packaging" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and 
extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/bc/89/f690c7d282200d6e36078f4bfbb9e6862102105c062fbf9b518c5b72df38/megatron_core-0.17.0.tar.gz", hash = "sha256:ff66c206ed164bc602ff00310388605fac41f284262176e17246a9e94163b205", size = 1385595, upload-time = "2026-04-16T20:22:32.079Z" } wheels = [ @@ -3783,6 +3867,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/44/0ee6bca0e8056d6daf0c21f15f74e36b2628318e19dd78dfaac185c6b547/megatron_core-0.17.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7a54ad8a8e221ba989a721da73496cc86ecd84ec79a711449060a15d690005b5", size = 1725175, upload-time = "2026-04-16T20:22:30.032Z" }, ] +[[package]] +name = "mistral-common" +version = "1.11.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonschema" }, + { name = "numpy" }, + { name = "pillow" }, + { name = "pydantic" }, + { name = "pydantic-extra-types", extra = ["pycountry"], marker = "extra == 'extra-12-openpipe-art-megatron'" }, + { name = "requests" }, + { name = "tiktoken" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2e/03/3c5d4c9430da406f8444f9a7b058a6aa89c525fb068a57fe2ab8b04a6d08/mistral_common-1.11.3.tar.gz", hash = "sha256:6437e128fc8a307318440839ca14ddf2e8060056b062233ec0db10352651374c", size = 6360629, upload-time = "2026-06-04T09:01:11.131Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/7b/76/dbfdf9c59e2a4b0116587626a3768c2a3b2ba1758b5756743918c2337fdc/mistral_common-1.11.3-py3-none-any.whl", hash = "sha256:dbfcef9d0c892727ee08a080f0c1039baed5430b291f5425ffd88892bf09e52c", size = 6533154, upload-time = "2026-06-04T09:01:14.186Z" }, +] + [[package]] name = "ml-dtypes" version = "0.5.4" @@ -4253,6 +4356,7 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" }, { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, + { url = "https://files.pythonhosted.org/packages/70/61/7d7b3c70186fb651d0fbd35b01dbfc8e755f69fd58f817f3d0f642df20c3/nvidia_cublas_cu12-12.8.4.1-py3-none-win_amd64.whl", hash = "sha256:47e9b82132fa8d2b4944e708049229601448aaad7e6f296f630f2d1a32de35af", size = 567544208, upload-time = "2025-03-07T01:53:30.535Z" }, ] [[package]] @@ -4262,6 +4366,7 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" }, { url = 
"https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, + { url = "https://files.pythonhosted.org/packages/41/bc/83f5426095d93694ae39fe1311431b5d5a9bb82e48bf0dd8e19be2765942/nvidia_cuda_cupti_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:bb479dcdf7e6d4f8b0b01b115260399bf34154a1a2e9fe11c85c517d87efd98e", size = 7015759, upload-time = "2025-03-07T01:51:11.355Z" }, ] [[package]] @@ -4271,6 +4376,7 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, { url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" }, + { url = "https://files.pythonhosted.org/packages/45/51/52a3d84baa2136cc8df15500ad731d74d3a1114d4c123e043cb608d4a32b/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:7a4b6b2904850fe78e0bd179c4b655c404d4bb799ef03ddc60804247099ae909", size = 73586838, upload-time = "2025-03-07T01:52:13.483Z" }, ] [[package]] @@ -4280,6 +4386,7 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = 
"https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" }, { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, + { url = "https://files.pythonhosted.org/packages/30/a5/a515b7600ad361ea14bfa13fb4d6687abf500adc270f19e89849c0590492/nvidia_cuda_runtime_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:c0c6027f01505bfed6c3b21ec546f69c687689aad5f1a377554bc6ca4aa993a8", size = 944318, upload-time = "2025-03-07T01:51:01.794Z" }, ] [[package]] @@ -4287,11 +4394,12 @@ name = "nvidia-cudnn-cu12" version = "9.19.0.56" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/09/b8/277c51962ee46fa3e5b203ac5f76107c650f781d6891e681e28e6f3e9fe6/nvidia_cudnn_cu12-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:08caaf27fe556aca82a3ee3b5aa49a77e7de0cfcb7ff4e5c29da426387a8267e", size = 656910700, upload-time = "2026-02-03T20:40:25.508Z" }, { url = 
"https://files.pythonhosted.org/packages/c5/41/65225d42fba06fb3dd3972485ea258e7dd07a40d6e01c95da6766ad87354/nvidia_cudnn_cu12-9.19.0.56-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:ac6ad90a075bb33a94f2b4cf4622eac13dd4dc65cf6dd9c7572a318516a36625", size = 657906812, upload-time = "2026-02-03T20:44:12.638Z" }, + { url = "https://files.pythonhosted.org/packages/a7/a5/48f07449fc9c6cc146dcafe6149fa5d69630137d2ec5b7d9e09f255fadd7/nvidia_cudnn_cu12-9.19.0.56-py3-none-win_amd64.whl", hash = "sha256:cec70596b9ce878fab83810c3f5a2e606d35f510e5fee579759e4cbc68a23750", size = 644003014, upload-time = "2026-02-03T20:46:25.768Z" }, ] [[package]] @@ -4301,10 +4409,13 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/0e/eb/22b4cad479206a3824edf494582e19fc4a291b9c14febdb859e56b82c03f/nvidia_cudnn_frontend-1.20.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bb891643598ac7b3734b82e5a459cbf778e467ebf7a5b586840003fb66df0ef3", size = 2371995, upload-time = "2026-03-16T18:29:29.024Z" }, { url = "https://files.pythonhosted.org/packages/aa/83/ee43fc097f475367f1ff5d5e3e1d8191d253f486cdd502d13600759fb845/nvidia_cudnn_frontend-1.20.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce50afe3d1efda07f52e8df5e992f33e92dbb443d0e61e2de703ad5762edc53c", size = 2521021, upload-time = "2026-03-16T18:25:37.316Z" }, + { url = "https://files.pythonhosted.org/packages/cc/03/d2d725c9c6eb04cd4a3216a7d1a37ab825d2ae8822b79a78b458ab703607/nvidia_cudnn_frontend-1.20.0-cp312-cp312-win_amd64.whl", hash = "sha256:f2449b0cfc547688e27f975c6ad5101257ae86df0315a80f28af78995adf55b6", size = 1944734, upload-time = "2026-03-16T18:33:02.866Z" }, { url = "https://files.pythonhosted.org/packages/d7/26/e5a309fe92ad67f2dc1ea85b2615f40db6c19f6a7b36b40036d57ae23a66/nvidia_cudnn_frontend-1.20.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:651fdc9a61b0a4456b557d5f82fab72739b0a6ee61384a4cb23767191e2640cd", size = 2371699, upload-time = "2026-03-16T18:30:19.865Z" }, { url = "https://files.pythonhosted.org/packages/2d/6f/a9f5df2e003ce6f57b6e609e323fc13379a0f7966d2e044de4ceb87ec4b4/nvidia_cudnn_frontend-1.20.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f317548e700f74c167fa4988de5f0ac06931820e4d0c35b5c7dfe629dd191be4", size = 2521383, upload-time = "2026-03-16T18:26:12.09Z" }, + { url = "https://files.pythonhosted.org/packages/90/8f/cba72a4deb5168bba97d0094dbfe05591a12bc9cc9432bbfd0c107ddca33/nvidia_cudnn_frontend-1.20.0-cp313-cp313-win_amd64.whl", hash = "sha256:64e5c21853732a2f6ecf031d95d100656514d43fd2260f64266b5f8536f46434", size = 1944767, upload-time = "2026-03-16T18:33:25.204Z" }, { url = "https://files.pythonhosted.org/packages/f9/a0/d2634d910257e6827d178dcebdf109f7f2bd8003659675dffc82fa101077/nvidia_cudnn_frontend-1.20.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6a1cf3e86664fb64e4752d3936d9cebd0afa6c4b5f6ccde19b6ee4d65fcd9d17", size = 2373944, upload-time = "2026-03-16T18:31:06.31Z" }, { url = "https://files.pythonhosted.org/packages/79/a2/dd2a75942b0311a50bfef3173b240695a5ebdbcbd3c5154d8f333ef6dac6/nvidia_cudnn_frontend-1.20.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f4da0e9ed299843abdcccdde73392577809403d4ef2ad26b4335a3eaee42423f", size = 2522596, upload-time = "2026-03-16T18:26:34.249Z" }, + { url = "https://files.pythonhosted.org/packages/ce/af/7110cea67a8cc8f3cd129cead952f5d50078c8bb99cf35e9f78c74a27097/nvidia_cudnn_frontend-1.20.0-cp314-cp314-win_amd64.whl", hash = "sha256:3f596e54398efab24727fc47291c61f969051f37e57e186ffe0fb6df06db19fd", size = 1946060, upload-time = "2026-03-16T18:33:47.963Z" }, ] [[package]] @@ -4312,11 +4423,12 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = 
"nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, + { url = "https://files.pythonhosted.org/packages/7d/ec/ce1629f1e478bb5ccd208986b5f9e0316a78538dd6ab1d0484f012f8e2a1/nvidia_cufft_cu12-11.3.3.83-py3-none-win_amd64.whl", hash = "sha256:7a64a98ef2a7c47f905aaf8931b69a3a43f27c55530c698bb2ed7c75c0b42cb7", size = 192216559, upload-time = "2025-03-07T01:53:57.106Z" }, ] [[package]] @@ -4335,6 +4447,7 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" }, { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = 
"sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, + { url = "https://files.pythonhosted.org/packages/b9/75/70c05b2f3ed5be3bb30b7102b6eb78e100da4bbf6944fd6725c012831cab/nvidia_curand_cu12-10.3.9.90-py3-none-win_amd64.whl", hash = "sha256:f149a8ca457277da854f89cf282d6ef43176861926c7ac85b2a0fbd237c587ec", size = 62765309, upload-time = "2025-03-07T01:54:20.478Z" }, ] [[package]] @@ -4342,13 +4455,14 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = 
"2025-03-07T01:46:54.356Z" }, { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, + { url = "https://files.pythonhosted.org/packages/13/c0/76ca8551b8a84146ffa189fec81c26d04adba4bc0dbe09cd6e6fd9b7de04/nvidia_cusolver_cu12-11.7.3.90-py3-none-win_amd64.whl", hash = "sha256:4a550db115fcabc4d495eb7d39ac8b58d4ab5d8e63274d3754df1c0ad6a22d34", size = 256720438, upload-time = "2025-03-07T01:54:39.898Z" }, ] [[package]] @@ -4356,11 +4470,12 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, + { url = 
"https://files.pythonhosted.org/packages/62/07/f3b2ad63f8e3d257a599f422ae34eb565e70c41031aecefa3d18b62cabd1/nvidia_cusparse_cu12-12.5.8.93-py3-none-win_amd64.whl", hash = "sha256:9a33604331cb2cac199f2e7f5104dfbb8a5a898c367a53dfda9ff2acb6b6b4dd", size = 284937404, upload-time = "2025-03-07T01:55:07.742Z" }, ] [[package]] @@ -4370,6 +4485,7 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/73/b9/598f6ff36faaece4b3c50d26f50e38661499ff34346f00e057760b35cc9d/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5", size = 283835557, upload-time = "2025-02-26T00:16:54.265Z" }, { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, + { url = "https://files.pythonhosted.org/packages/2f/d8/a6b0d0d0c2435e9310f3e2bb0d9c9dd4c33daef86aa5f30b3681defd37ea/nvidia_cusparselt_cu12-0.7.1-py3-none-win_amd64.whl", hash = "sha256:f67fbb5831940ec829c9117b7f33807db9f9678dc2a617fbe781cac17b4e1075", size = 271020911, upload-time = "2025-02-26T00:14:47.204Z" }, ] [[package]] @@ -4430,8 +4546,8 @@ dependencies = [ { name = "safetensors" }, { name = "scipy" }, { name = "setuptools" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 
'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "tqdm" }, ] wheels = [ @@ -4454,6 +4570,7 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, { url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" }, + { url = "https://files.pythonhosted.org/packages/ed/d7/34f02dad2e30c31b10a51f6b04e025e5dd60e5f936af9045a9b858a05383/nvidia_nvjitlink_cu12-12.8.93-py3-none-win_amd64.whl", hash = 
"sha256:bd93fbeeee850917903583587f4fc3a4eafa022e34572251368238ab5e6bd67f", size = 268553710, upload-time = "2025-03-07T01:56:24.13Z" }, ] [[package]] @@ -4472,6 +4589,7 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", size = 91161, upload-time = "2025-03-07T01:42:23.922Z" }, { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, + { url = "https://files.pythonhosted.org/packages/9f/99/4c9c0c329bf9fc125008c3b54c7c94c0023518d06fc025ae36431375e1fe/nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e", size = 56492, upload-time = "2025-03-07T01:52:24.69Z" }, ] [[package]] @@ -4484,8 +4602,8 @@ dependencies = [ { name = "packaging" }, { name = "psutil" }, { name = "pyyaml" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 
'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/70/05/38d491962273c7905708762279f440520eb79f3c00b67a023497215ad023/nvidia_resiliency_ext-0.4.1-cp312-cp312-manylinux_2_31_aarch64.whl", hash = "sha256:b3bd5f01535574b16d0f38bca6e39afe3806c4a2896eee1b321cd944e00025a7", size = 444570, upload-time = "2025-07-17T03:50:58.877Z" }, @@ -4587,8 +4705,8 @@ dependencies = [ { name = "regex" }, { name = "safetensors" }, { name = "timm" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' 
and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torchvision" }, { name = "tqdm" }, ] @@ -4647,10 +4765,10 @@ backend = [ { name = "pyarrow" }, { name = "pytest" }, { name = "setuptools" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 
'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torchao" }, - { name = "transformers" }, + { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, { name = "trl" }, { name = "unsloth" }, { name = "unsloth-zoo" }, @@ -4676,11 +4794,12 @@ megatron = [ { name = "pybind11" }, { name = "quack-kernels" }, { name = "tilelang", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", 
source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "transformer-engine" }, { name = "transformer-engine-cu12" }, { name = "transformer-engine-torch" }, + { name = "transformers", version = "5.6.2", source = { registry = "https://pypi.org/simple" } }, ] plotting = [ { name = "matplotlib" }, @@ -4697,9 +4816,9 @@ tinker = [ { name = "pydantic" }, { name = "tinker" }, { name = "tinker-cookbook" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "transformers" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 
'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra != 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, { name = "uvicorn" }, ] @@ -4743,7 +4862,7 @@ requires-dist = [ { name = "litellm", specifier = ">=1.71.1,<=1.82.0" }, { name = "mamba-ssm", marker = "python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'megatron'", specifier = "==2.3.1" }, { name = "matplotlib", marker = "extra == 'plotting'", specifier = ">=3.10.1" }, - { name = "megatron-bridge", marker = "extra == 'megatron'", git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git?rev=e049cc00c24d03e2ae45d2608c7a44e2d2364e3d" }, + { name = "megatron-bridge", marker = "extra == 'megatron'", git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git?rev=e1a207ac757e5d0ed94d8ffbe1cbd28e81d8c084" }, { name = "megatron-core", marker = "extra == 'megatron'", specifier = "==0.17.0" }, { name = "ml-dtypes", marker = "python_full_version < '3.13' and extra == 'megatron'", specifier = ">=0.5.0" }, { name = "nbclient", marker = "extra == 'backend'", specifier = ">=0.10.1" }, @@ -4775,18 +4894,19 @@ requires-dist = [ { name = "tilelang", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'megatron'", specifier = "==0.1.10" }, { name = "tinker", marker = "extra == 'tinker'", specifier = ">=0.21.0,<0.22" }, { name = "tinker-cookbook", marker = "extra == 'tinker'", specifier = ">=0.4.1,<0.5" }, - { name = "torch", marker = "(sys_platform == 'linux' and extra == 'backend') or (sys_platform == 'win32' and extra == 
'backend')", specifier = ">=2.11.0", index = "https://download.pytorch.org/whl/cu128" }, - { name = "torch", marker = "(sys_platform == 'linux' and extra == 'megatron') or (sys_platform == 'win32' and extra == 'megatron')", specifier = ">=2.11.0", index = "https://download.pytorch.org/whl/cu128" }, - { name = "torch", marker = "(sys_platform == 'linux' and extra == 'tinker') or (sys_platform == 'win32' and extra == 'tinker')", specifier = ">=2.11.0", index = "https://download.pytorch.org/whl/cu128" }, - { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32' and extra == 'backend'", specifier = ">=2.11.0" }, - { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32' and extra == 'megatron'", specifier = ">=2.11.0" }, - { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32' and extra == 'tinker'", specifier = ">=2.11.0" }, + { name = "torch", marker = "(sys_platform == 'linux' and extra == 'backend') or (sys_platform == 'win32' and extra == 'backend')", specifier = "==2.11.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torch", marker = "(sys_platform == 'linux' and extra == 'megatron') or (sys_platform == 'win32' and extra == 'megatron')", specifier = "==2.11.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torch", marker = "(sys_platform == 'linux' and extra == 'tinker') or (sys_platform == 'win32' and extra == 'tinker')", specifier = "==2.11.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32' and extra == 'backend'", specifier = "==2.11.0" }, + { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32' and extra == 'megatron'", specifier = "==2.11.0" }, + { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32' and extra == 'tinker'", specifier = "==2.11.0" }, { name = "torchao", marker = "extra == 'backend'", specifier = 
"==0.16.0" }, { name = "transformer-engine", marker = "extra == 'megatron'", specifier = "==2.11.0" }, { name = "transformer-engine-cu12", marker = "extra == 'megatron'", specifier = "==2.11.0" }, { name = "transformer-engine-torch", marker = "extra == 'megatron'", git = "https://github.com/NVIDIA/TransformerEngine.git?subdirectory=transformer_engine%2Fpytorch&rev=v2.11" }, { name = "transformers", marker = "extra == 'backend'", specifier = "==5.2.0" }, - { name = "transformers", marker = "extra == 'tinker'", specifier = "==5.2.0" }, + { name = "transformers", marker = "extra == 'megatron'", specifier = "==5.6.2" }, + { name = "transformers", marker = "extra == 'tinker'", specifier = ">=5.2.0,<=5.5.3" }, { name = "trl", marker = "extra == 'backend'", specifier = "==0.20.0" }, { name = "typer", specifier = ">=0.15.2" }, { name = "unsloth", marker = "extra == 'backend'", specifier = "==2026.3.3" }, @@ -5112,10 +5232,11 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 
'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "tqdm" }, - { name = "transformers" }, + { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-12-openpipe-art-backend' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "transformers", version = "5.6.2", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-12-openpipe-art-megatron'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/86/cf/037f1e3d5186496c05513a6754639e2dab3038a05f384284d49a9bd06a2d/peft-0.19.1.tar.gz", hash = "sha256:0d97542fe96dcdaa20d3b81c06f26f988618f416a73544ab23c3618ccb674a40", size = 763738, upload-time = "2026-04-16T15:46:45.105Z" } wheels = [ @@ -5327,7 +5448,7 @@ dependencies = [ { name = "networkx" }, { name = "pdfminer-six" }, { name = "pillow" }, - { name = "pyreadline3", marker = "sys_platform == 'win32'" }, + { name = "pyreadline3", marker = "sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "pyyaml" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/70/55/e5400762e3884f743d59291e71eaaa9c52dd7e144b75a11911e74ec1bac9/polyfile_weave-0.5.9.tar.gz", hash = "sha256:12341fab03e06ede1bfebbd3627dd24015fde5353ea74ece2da186321b818bdb", size = 6024974, upload-time = "2026-01-22T22:08:48.081Z" } @@ -5763,6 +5884,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/e3/0f15da0fb5864a37637820e4bde463a52ba0c052a8edab06aad46b9e578b/pycasbin-2.8.0-py3-none-any.whl", hash = "sha256:1a9e370de553c677c4dff75a5d6f3b0eb354b73b20d7df77ff4ee61a71267a3a", size = 476153, upload-time = "2026-02-02T03:34:12.555Z" }, ] +[[package]] +name = "pycountry" +version = "26.2.16" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/de/1d/061b9e7a48b85cfd69f33c33d2ef784a531c359399ad764243399673c8f5/pycountry-26.2.16.tar.gz", hash = "sha256:5b6027d453fcd6060112b951dd010f01f168b51b4bf8a1f1fc8c95c8d94a0801", size = 7711342, upload-time = "2026-02-17T03:42:52.367Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/42/7703bd45b62fecd44cd7d3495423097e2f7d28bc2e99e7c1af68892ab157/pycountry-26.2.16-py3-none-any.whl", hash = "sha256:115c4baf7cceaa30f59a4694d79483c9167dbce7a9de4d3d571c5f3ea77c305a", size = 8044600, upload-time = "2026-02-17T03:42:49.777Z" }, +] + [[package]] name = "pycparser" version = "3.0" @@ -5910,6 +6040,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/c1/3226e6d7f5a4f736f38ac11a6fbb262d701889802595cdb0f53a885ac2e0/pydantic_extra_types-2.11.1-py3-none-any.whl", hash = "sha256:1722ea2bddae5628ace25f2aa685b69978ef533123e5638cfbddb999e0100ec1", size = 79526, upload-time = "2026-03-16T08:08:02.533Z" }, ] +[package.optional-dependencies] +pycountry = [ + { name = "pycountry" }, +] + [[package]] name = "pydantic-settings" version = "2.14.1" @@ -5968,7 +6103,7 @@ name = "pynacl" version = "1.6.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cffi", marker = 
"platform_python_implementation != 'PyPy'" }, + { name = "cffi", marker = "platform_python_implementation != 'PyPy' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d9/9a/4019b524b03a13438637b11538c82781a5eda427394380381af8f04f467a/pynacl-1.6.2.tar.gz", hash = "sha256:018494d6d696ae03c7e656e5e74cdfd8ea1326962cc401bcf018f1ed8436811c", size = 3511692, upload-time = "2026-01-01T17:48:10.851Z" } wheels = [ @@ -6042,7 +6177,7 @@ name = "pytest" version = "9.0.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "iniconfig" }, { name = "packaging" }, { name = "pluggy" }, @@ -6059,7 +6194,7 @@ version = "1.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pytest" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/43/7c/d36d04db312ecf4298932ef77e6e4a9e8ad017906e24e34f0b0c361a2473/pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42", size = 58514, upload-time = "2026-05-26T09:56:04.083Z" } wheels = [ @@ -6245,7 +6380,7 @@ name = "pyzmq" version = "27.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name 
= "cffi", marker = "implementation_name == 'pypy'" }, + { name = "cffi", marker = "implementation_name == 'pypy' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/04/0b/3c9baedbdf613ecaa7aa07027780b8867f57b6293b6ee50de316c9f3222b/pyzmq-27.1.0.tar.gz", hash = "sha256:ac0765e3d44455adb6ddbf4417dcce460fc40a05978c08efdf2948072f6db540", size = 281750, upload-time = "2025-09-08T23:10:18.157Z" } wheels = [ @@ -6290,8 +6425,8 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "apache-tvm-ffi" }, { name = "nvidia-cutlass-dsl" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra 
== 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torch-c-dlpack-ext" }, ] sdist = { url = "https://files.pythonhosted.org/packages/95/11/6b1664d0e85f91f4549403d4ca6c9248857080f571397da7cb7570338dcd/quack_kernels-0.3.7.tar.gz", hash = "sha256:1c35a3f6f8c06b38cdf6a68d95fbb52e2b75cd261d0f01abcb7cec5d1bd80ca1", size = 193338, upload-time = "2026-03-27T19:55:55.544Z" } @@ -6321,7 +6456,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, { name = "rpds-py" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8", size = 78036, upload-time = "2025-10-13T15:30:48.871Z" } wheels = [ @@ -6896,8 +7031,8 @@ name = "secretstorage" version = "3.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cryptography", marker = "sys_platform == 'linux'" }, - { name = "jeepney", marker = "sys_platform == 'linux'" }, + { name = "cryptography", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "jeepney", marker = "sys_platform == 'linux' 
or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/1c/03/e834bcd866f2f8a49a85eaff47340affa3bfa391ee9912a952a1faa68c7b/secretstorage-3.5.0.tar.gz", hash = "sha256:f04b8e4689cbce351744d5537bf6b1329c6fc68f91fa666f60a380edddcd11be", size = 19884, upload-time = "2025-11-23T19:02:53.191Z" } wheels = [ @@ -7337,7 +7472,7 @@ name = "sqlalchemy" version = "2.0.50" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "greenlet", marker = "platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'" }, + { name = "greenlet", marker = "platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/57/da/6fbf010c8ebb347679d0d100b22fe9ba5e13fd04046c5df7280d2f0bf706/sqlalchemy-2.0.50.tar.gz", hash = "sha256:af5607d11ef90fd6a5c0549fe0045dce1663d427426bcfb506dcb5346a85a3b9", size = 9907424, upload-time = "2026-05-24T19:20:04.018Z" } @@ -7415,7 +7550,7 @@ version = "0.52.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", 
marker = "python_full_version < '3.13' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/c4/68/79977123bb7be889ad680d79a40f339082c1978b5cfcf62c2d8d196873ac/starlette-0.52.1.tar.gz", hash = "sha256:834edd1b0a23167694292e94f597773bc3f89f362be6effee198165a35d62933", size = 2653702, upload-time = "2026-01-18T13:34:11.062Z" } wheels = [ @@ -7561,20 +7696,23 @@ name = "tilelang" version = "0.1.10" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "apache-tvm-ffi", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, - { name = "cloudpickle", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, - { name = "ml-dtypes", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, - { name = "numpy", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, - { name = "psutil", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, - { name = "torch-c-dlpack-ext", marker = "python_full_version < '3.14' and platform_machine != 's390x' and sys_platform == 'linux'" }, - { name = "tqdm", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, - { name = "z3-solver", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "apache-tvm-ffi", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "cloudpickle", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "ml-dtypes", marker = "sys_platform == 'linux' or sys_platform == 
'win32'" }, + { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "psutil", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch-c-dlpack-ext", marker = "(python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" }, + { name = "tqdm", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "z3-solver", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/77/5c/07146b4527656102e48d21c2599aa80477e83ea3f149ac0df3b15a247bd4/tilelang-0.1.10.tar.gz", hash = "sha256:d8813e668fcf75843bc2d68c633c352b419c1e292895a6038a4aadd943e56c2b", size = 93184128, upload-time = "2026-05-25T03:58:57.006Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/76/0f/e5e01399adb5110bf885e19e879229e3fc578e1e035939f601365305c825/tilelang-0.1.10-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:246084babf0f6801ad2b8ac1d58cead37520974ae399247c89d42b68872d2cf9", size = 38492226, upload-time = "2026-05-25T03:55:30.729Z" }, { url = "https://files.pythonhosted.org/packages/b0/66/ab4301dc38ca9f09832df2936c73388c611c198dc938634acb6ce80dfa74/tilelang-0.1.10-cp38-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85180d1a96defeecdf52d5d075a31c3fc551d8485981e6b636762a9cd7eb02fe", size = 49768455, upload-time = "2026-05-25T03:56:17.081Z" }, + { url = "https://files.pythonhosted.org/packages/92/af/a3dfc43dad228a6e560863f071865d5a27c35b050a9fc431641cb07135d1/tilelang-0.1.10-cp38-abi3-manylinux_2_34_aarch64.whl", hash = 
"sha256:15437e5f0daa0863ac9a5386007847c94070f5ab3234d040bc947afdf2f57100", size = 45629488, upload-time = "2026-05-25T03:57:00.374Z" }, + { url = "https://files.pythonhosted.org/packages/c3/36/2096dce95c20e13be5b5ce852190ca4b4ac41c7fd9b91a0be98353598153/tilelang-0.1.10-cp38-abi3-win_amd64.whl", hash = "sha256:93dd078113d275352698a6e72a91e80e5b0263d22a005109b3db2c1c016ea105", size = 33692452, upload-time = "2026-05-25T03:57:32.576Z" }, ] [[package]] @@ -7585,8 +7723,8 @@ dependencies = [ { name = "huggingface-hub" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and 
extra == 'extra-12-openpipe-art-tinker')" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/08/54/ece85b0eef3700c90db8271a43669b05a0ebbe2edb1962329c34374a297e/timm-1.0.27.tar.gz", hash = "sha256:315dfe63186ca9fb7ff941268941231fd5be259f2b4bb4afa28560ae1015cb9a", size = 2439861, upload-time = "2026-05-08T19:38:36.844Z" } @@ -7603,14 +7741,14 @@ dependencies = [ { name = "click" }, { name = "distro" }, { name = "grpcio" }, - { name = "httpx", extra = ["http2"] }, + { name = "httpx", extra = ["http2"], marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "numpy" }, { name = "orjson" }, { name = "protobuf" }, { name = "pydantic" }, { name = "rich" }, { name = "sniffio" }, - { name = "transformers" }, + { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/bc/7a/a72cb2b487a7581cc192f73fd64250d14434702c4e83b2da3d5924d5ecbc/tinker-0.21.0.tar.gz", hash = "sha256:8d72709fb639f74bf90f1d1fd57beec53bfc147a768a8f42e5d6b4404eeccce9", size = 251660, upload-time = "2026-05-19T00:24:02.569Z" } @@ -7638,10 +7776,10 @@ dependencies = [ { name = "termcolor" }, { name = "tiktoken" }, { name = "tinker" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform != 'linux' and sys_platform 
!= 'win32' and extra != 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra != 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "tqdm" }, - { name = "transformers" }, + { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, ] sdist = { url = "https://files.pythonhosted.org/packages/e5/9c/37af9804cb3f1d88f5e67512aa1aeafeb49ef9012532d056d92c96194320/tinker_cookbook-0.4.1.tar.gz", hash = "sha256:1f9ad977317529bbf796f40ef13de59b2c93a0a257469bd80a7ffcfed5beb8b2", size = 4517724, upload-time = "2026-05-12T03:49:19.6Z" } wheels = [ @@ -7742,12 +7880,18 @@ name = "torch" version = "2.11.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - 
"python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and sys_platform != 'linux' and sys_platform != 
'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version >= '3.14' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", ] dependencies = [ { name = "filelock", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, @@ -7760,10 +7904,25 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/6f/8b/69e3008d78e5cee2b30183340cc425081b78afc5eff3d080daab0adda9aa/torch-2.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4b5866312ee6e52ea625cd211dcb97d6a2cdc1131a5f15cc0d87eec948f6dd34", size = 80606338, upload-time = "2026-03-23T18:11:34.781Z" }, + { url = "https://files.pythonhosted.org/packages/13/16/42e5915ebe4868caa6bac83a8ed59db57f12e9a61b7d749d584776ed53d5/torch-2.11.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f99924682ef0aa6a4ab3b1b76f40dc6e273fca09f367d15a524266db100a723f", size = 419731115, upload-time = "2026-03-23T18:11:06.944Z" }, + { url = "https://files.pythonhosted.org/packages/1a/c9/82638ef24d7877510f83baf821f5619a61b45568ce21c0a87a91576510aa/torch-2.11.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = 
"sha256:0f68f4ac6d95d12e896c3b7a912b5871619542ec54d3649cf48cc1edd4dd2756", size = 530712279, upload-time = "2026-03-23T18:10:31.481Z" }, + { url = "https://files.pythonhosted.org/packages/1c/ff/6756f1c7ee302f6d202120e0f4f05b432b839908f9071157302cedfc5232/torch-2.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:fbf39280699d1b869f55eac536deceaa1b60bd6788ba74f399cc67e60a5fab10", size = 114556047, upload-time = "2026-03-23T18:10:55.931Z" }, { url = "https://files.pythonhosted.org/packages/87/89/5ea6722763acee56b045435fb84258db7375c48165ec8be7880ab2b281c5/torch-2.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1e6debd97ccd3205bbb37eb806a9d8219e1139d15419982c09e23ef7d4369d18", size = 80606801, upload-time = "2026-03-23T18:10:18.649Z" }, + { url = "https://files.pythonhosted.org/packages/32/d1/8ed2173589cbfe744ed54e5a73efc107c0085ba5777ee93a5f4c1ab90553/torch-2.11.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:63a68fa59de8f87acc7e85a5478bb2dddbb3392b7593ec3e78827c793c4b73fd", size = 419732382, upload-time = "2026-03-23T18:08:30.835Z" }, + { url = "https://files.pythonhosted.org/packages/3d/e1/b73f7c575a4b8f87a5928f50a1e35416b5e27295d8be9397d5293e7e8d4c/torch-2.11.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:cc89b9b173d9adfab59fd227f0ab5e5516d9a52b658ae41d64e59d2e55a418db", size = 530711509, upload-time = "2026-03-23T18:08:47.213Z" }, + { url = "https://files.pythonhosted.org/packages/66/82/3e3fcdd388fbe54e29fd3f991f36846ff4ac90b0d0181e9c8f7236565f82/torch-2.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:4dda3b3f52d121063a731ddb835f010dc137b920d7fec2778e52f60d8e4bf0cd", size = 114555842, upload-time = "2026-03-23T18:09:52.111Z" }, { url = "https://files.pythonhosted.org/packages/db/38/8ac78069621b8c2b4979c2f96dc8409ef5e9c4189f6aac629189a78677ca/torch-2.11.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8b394322f49af4362d4f80e424bcaca7efcd049619af03a4cf4501520bdf0fb4", size = 80959574, upload-time = "2026-03-23T18:10:14.214Z" }, + { url 
= "https://files.pythonhosted.org/packages/6d/6c/56bfb37073e7136e6dd86bfc6af7339946dd684e0ecf2155ac0eee687ae1/torch-2.11.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2658f34ce7e2dabf4ec73b45e2ca68aedad7a5be87ea756ad656eaf32bf1e1ea", size = 419732324, upload-time = "2026-03-23T18:09:36.604Z" }, + { url = "https://files.pythonhosted.org/packages/07/f4/1b666b6d61d3394cca306ea543ed03a64aad0a201b6cd159f1d41010aeb1/torch-2.11.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:98bb213c3084cfe176302949bdc360074b18a9da7ab59ef2edc9d9f742504778", size = 530596026, upload-time = "2026-03-23T18:09:20.842Z" }, + { url = "https://files.pythonhosted.org/packages/48/6b/30d1459fa7e4b67e9e3fe1685ca1d8bb4ce7c62ef436c3a615963c6c866c/torch-2.11.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a97b94bbf62992949b4730c6cd2cc9aee7b335921ee8dc207d930f2ed09ae2db", size = 114793702, upload-time = "2026-03-23T18:09:47.304Z" }, { url = "https://files.pythonhosted.org/packages/26/0d/8603382f61abd0db35841148ddc1ffd607bf3100b11c6e1dab6d2fc44e72/torch-2.11.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:01018087326984a33b64e04c8cb5c2795f9120e0d775ada1f6638840227b04d7", size = 80573442, upload-time = "2026-03-23T18:09:10.117Z" }, + { url = "https://files.pythonhosted.org/packages/c7/86/7cd7c66cb9cec6be330fff36db5bd0eef386d80c031b581ec81be1d4b26c/torch-2.11.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:2bb3cc54bd0dea126b0060bb1ec9de0f9c7f7342d93d436646516b0330cd5be7", size = 419749385, upload-time = "2026-03-23T18:07:33.77Z" }, + { url = "https://files.pythonhosted.org/packages/47/e8/b98ca2d39b2e0e4730c0ee52537e488e7008025bc77ca89552ff91021f7c/torch-2.11.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4dc8b3809469b6c30b411bb8c4cad3828efd26236153d9beb6a3ec500f211a60", size = 530716756, upload-time = "2026-03-23T18:07:50.02Z" }, + { url = 
"https://files.pythonhosted.org/packages/78/88/d4a4cda8362f8a30d1ed428564878c3cafb0d87971fbd3947d4c84552095/torch-2.11.0-cp314-cp314-win_amd64.whl", hash = "sha256:2b4e811728bd0cc58fb2b0948fe939a1ee2bf1422f6025be2fca4c7bd9d79718", size = 114552300, upload-time = "2026-03-23T18:09:05.617Z" }, { url = "https://files.pythonhosted.org/packages/bf/46/4419098ed6d801750f26567b478fc185c3432e11e2cad712bc6b4c2ab0d0/torch-2.11.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:8245477871c3700d4370352ffec94b103cfcb737229445cf9946cddb7b2ca7cd", size = 80959460, upload-time = "2026-03-23T18:09:00.818Z" }, + { url = "https://files.pythonhosted.org/packages/fd/66/54a56a4a6ceaffb567231994a9745821d3af922a854ed33b0b3a278e0a99/torch-2.11.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:ab9a8482f475f9ba20e12db84b0e55e2f58784bdca43a854a6ccd3fd4b9f75e6", size = 419735835, upload-time = "2026-03-23T18:07:18.974Z" }, + { url = "https://files.pythonhosted.org/packages/b1/e7/0b6665f533aa9e337662dc190425abc0af1fe3234088f4454c52393ded61/torch-2.11.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:563ed3d25542d7e7bbc5b235ccfacfeb97fb470c7fee257eae599adb8005c8a2", size = 530613405, upload-time = "2026-03-23T18:08:07.014Z" }, + { url = "https://files.pythonhosted.org/packages/cf/bf/c8d12a2c86dbfd7f40fb2f56fbf5a505ccf2d9ce131eb559dfc7c51e1a04/torch-2.11.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b2a43985ff5ef6ddd923bbcf99943e5f58059805787c5c9a2622bf05ca2965b0", size = 114792991, upload-time = "2026-03-23T18:08:19.216Z" }, ] [[package]] @@ -7771,34 +7930,43 @@ name = "torch" version = "2.11.0+cu128" source = { registry = "https://download.pytorch.org/whl/cu128" } resolution-markers = [ - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'linux'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'linux'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", - 
"python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'win32'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", -] -dependencies = [ - { name = "cuda-bindings", marker = "sys_platform == 'linux'" }, - { name = "cuda-toolkit", extra = ["cublas", "cudart", "cufft", "cufile", "cupti", "curand", "cusolver", "cusparse", "nvjitlink", "nvrtc", "nvtx"], marker = "sys_platform == 'linux'" }, - { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "jinja2", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "networkx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux'" }, - { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "triton", marker = "sys_platform == 'linux'" }, - 
{ name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine != 
's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra != 'extra-12-openpipe-art-tinker'", + "python_full_version >= '3.14' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version >= '3.14' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "(python_full_version >= '3.14' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version >= '3.14' and sys_platform == 
'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", + "(python_full_version == '3.13.*' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", + "(python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", +] +dependencies = [ + { name = "cuda-bindings", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "cuda-toolkit", extra = ["cublas", "cudart", "cufft", "cufile", "cupti", "curand", "cusolver", "cusparse", "nvjitlink", "nvrtc", "nvtx"], marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "jinja2", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 
'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "networkx", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 
'extra-12-openpipe-art-tinker')" }, + { name = "triton", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9c8f38efee365cb9d334de8a83ce52fc7e5fc9e5a7b0853285efa1b69e00b0f2", upload-time = "2026-04-27T17:41:30Z" }, @@ -7823,8 +7991,8 @@ name = "torch-c-dlpack-ext" version = "0.1.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = 
"2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/37/de/921b6491efce5c389a5ef9bbed3d2d6660005840dae488124173180859ab/torch_c_dlpack_ext-0.1.5.tar.gz", hash = "sha256:d06f0357d575d22a168cc77acb9020fc4bae30968ceb6718a055dcbe92bacabe", size = 12913, upload-time = "2026-01-12T11:25:08.484Z" } wheels = [ @@ -7858,8 +8026,8 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "pillow" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = 
"torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9e/c8/5cd91932f7f3671b0743dc4ae1a4c16b1d0b45bf4087976277d325bda718/torchvision-0.27.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:1a6dd742a150645126df9e0b2e449874c1d635897c773b322c2e067e98382dfe", size = 1758824, upload-time = "2026-05-13T14:57:15.227Z" }, @@ -7906,7 +8074,7 @@ name = "tqdm" version = "4.67.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/09/a9/6ba95a270c6f1fbcd8dac228323f2777d886cb206987444e4bce66338dd4/tqdm-4.67.3.tar.gz", hash = "sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb", size = 169598, upload-time = "2026-02-03T17:35:53.048Z" } wheels = [ @@ -7966,8 +8134,8 @@ dependencies = [ { name = "onnxscript" }, { name = "packaging" }, { name = "pydantic" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = 
"https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "transformer-engine-cu12" }, ] @@ -7975,22 +8143,79 @@ dependencies = [ name = "transformers" version = "5.2.0" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "huggingface-hub" }, - { name = "numpy" }, - { name = "packaging" }, - { name = "pyyaml" }, - { name = "regex" }, - { name = "safetensors" }, - { name = "tokenizers" }, - { name = "tqdm" }, - { name = "typer-slim" }, +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", 
+ "python_full_version < '3.13' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version >= '3.14' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version >= '3.14' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "(python_full_version >= '3.14' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", + "(python_full_version == '3.13.*' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", + "(python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version < '3.13' and sys_platform == 
'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", + "python_full_version >= '3.14' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version == '3.13.*' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", + "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", +] +dependencies = [ + { name = "huggingface-hub", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "numpy", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "packaging", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "pyyaml", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "regex", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "safetensors", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "tokenizers", marker = "extra 
== 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "tqdm", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "typer-slim", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/bd/7e/8a0c57d562015e5b16c97c1f0b8e0e92ead2c7c20513225dc12c2043ba9f/transformers-5.2.0.tar.gz", hash = "sha256:0088b8b46ccc9eff1a1dca72b5d618a5ee3b1befc3e418c9512b35dea9f9a650", size = 8618176, upload-time = "2026-02-16T18:54:02.867Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/4e/93/79754b0ca486e556c2b95d4f5afc66aaf4b260694f3d6e1b51da2d036691/transformers-5.2.0-py3-none-any.whl", hash = "sha256:9ecaf243dc45bee11a7d93f8caf03746accc0cb069181bbf4ad8566c53e854b4", size = 10403304, upload-time = "2026-02-16T18:53:59.699Z" }, ] +[[package]] +name = "transformers" +version = "5.6.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 
'linux'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", +] +dependencies = [ + { name = "huggingface-hub", marker = "extra == 'extra-12-openpipe-art-megatron'" }, + { name = "numpy", marker = "extra == 'extra-12-openpipe-art-megatron'" }, + { name = "packaging", marker = "extra == 'extra-12-openpipe-art-megatron'" }, + { name = "pyyaml", marker = "extra == 'extra-12-openpipe-art-megatron'" }, + { name = "regex", marker = "extra == 'extra-12-openpipe-art-megatron'" }, + { name = "safetensors", marker = "extra == 'extra-12-openpipe-art-megatron'" }, + { name = "tokenizers", marker = "extra == 'extra-12-openpipe-art-megatron'" }, + { name = "tqdm", marker = "extra == 'extra-12-openpipe-art-megatron'" }, + { name = "typer", marker = "extra == 'extra-12-openpipe-art-megatron'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/a4/e9/c6c80a07690142a7d05444271f47b9f3c8aac7dea01d52e1137ee480ad78/transformers-5.6.2.tar.gz", hash = "sha256:e657134c3e5a6bc00a3c35f4e2674bb51adfcd89898495b788a18552bac2b91a", size = 8311867, upload-time = "2026-04-23T18:33:29.332Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/95/0b0218149b0d6f14df35f5b8f676fa83df4f19ed253c3cc447107ef86eca/transformers-5.6.2-py3-none-any.whl", hash = "sha256:f8d3a1bb96778fed9b8aabfd0dd6e19843e4b0f2bb6b59f32b8a92051b0f348f", size = 10364898, upload-time = "2026-04-23T18:33:26.081Z" }, +] + [[package]] name = "triton" version = "3.6.0" @@ -8025,7 +8250,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "accelerate" }, { name = "datasets" }, - { name = "transformers" }, + { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, ] sdist = { url = "https://files.pythonhosted.org/packages/60/11/95cf1210df9f241b7b1084abe1032e322374f667c4587c09af8d14a1d76f/trl-0.20.0.tar.gz", hash = "sha256:3f949b009b79dc609cd8f5469d67209ab8f71c5cb4d8d979f7b568ef054922fa", size = 461791, upload-time = "2025-07-29T04:10:06.305Z" } wheels = [ @@ -8083,7 +8308,7 @@ version = "0.26.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-doc" }, - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "rich" }, { name = "shellingham" }, ] @@ -8097,7 +8322,7 @@ name = "typer-slim" version = "0.24.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typer" }, + { name = "typer", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 
'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a7/a7/e6aecc4b4eb59598829a3b5076a93aff291b4fdaa2ded25efc4e1f4d219c/typer_slim-0.24.0.tar.gz", hash = "sha256:f0ed36127183f52ae6ced2ecb2521789995992c521a46083bfcdbb652d22ad34", size = 4776, upload-time = "2026-02-16T22:08:51.2Z" } wheels = [ @@ -8177,11 +8402,11 @@ dependencies = [ { name = "protobuf" }, { name = "psutil" }, { name = "sentencepiece" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torchvision" }, { name = "tqdm" }, - { name = "transformers" }, + { name = "transformers", version = "5.2.0", source = { 
registry = "https://pypi.org/simple" } }, { name = "triton", marker = "'linux' in sys_platform" }, { name = "triton-windows", marker = "(platform_machine == 'AMD64' and sys_platform == 'win32') or (platform_machine == 'x86_64' and sys_platform == 'win32')" }, { name = "trl" }, @@ -8215,11 +8440,11 @@ dependencies = [ { name = "psutil" }, { name = "regex" }, { name = "sentencepiece" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torchao" }, { name = "tqdm" }, - { name = "transformers" }, + { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, { name = "triton", marker = "'linux' in sys_platform" }, { name = "trl" }, 
{ name = "typing-extensions" }, @@ -8380,11 +8605,11 @@ wheels = [ [package.optional-dependencies] standard = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "httptools" }, { name = "python-dotenv" }, { name = "pyyaml" }, - { name = "uvloop", marker = "platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'" }, + { name = "uvloop", marker = "(platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32') or (platform_python_implementation == 'PyPy' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (platform_python_implementation == 'PyPy' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'cygwin' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'cygwin' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "watchfiles" }, { name = "websockets" }, ] @@ -8623,7 +8848,7 @@ dependencies = [ { name = "pydantic" }, { name = "sentry-sdk" }, { name = "tenacity" }, - { name = "tzdata", marker = "sys_platform == 'win32'" }, + { name = "tzdata", marker = "sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/96/7c/f0c54919dc390beaf33086e15abdc1b8499c6273c2035d73703ed8a0b9d6/weave-0.52.41.tar.gz", hash = "sha256:59159952f9c7c65d78dd4f7a96bfc13accb2f3d93cb43583af6c6d05c5036b4d", size = 937328, upload-time = "2026-05-19T22:03:03.124Z" } wheels = [ @@ -8792,9 +9017,9 @@ name = "xformers" version = "0.0.35" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "platform_machine != 's390x'" }, - { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(platform_machine != 's390x' and sys_platform == 'linux') or (platform_machine != 's390x' and sys_platform == 'win32')" }, + { name = "numpy" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/de/5a/6e27734bd793adc44d0b8d294e67cfacf4ec590572c1aef51d683fc7a791/xformers-0.0.35.tar.gz", hash = "sha256:f7fc183a58e4bf0e2ae339a18fb1b1d4a37854c0f2545b4f360fef001646ab76", size = 4258182, upload-time = "2026-02-20T20:33:05.417Z" } wheels = [ @@ -9003,7 +9228,12 @@ version = "4.15.4.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/8a/8e/0c8f17309549d2e5cde9a3ccefa6365437f1e7bafe71878eaf9478e47b18/z3_solver-4.15.4.0.tar.gz", hash = "sha256:928c29b58c4eb62106da51c1914f6a4a55d0441f8f48a81b9da07950434a8946", size = 5018600, upload-time = "2025-10-29T18:12:03.062Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/63/33/a3d5d2eaeb0f7b3174d57d405437eabb2075d4d50bd9ea0957696c435c7b/z3_solver-4.15.4.0-py3-none-macosx_13_0_arm64.whl", hash = "sha256:407e825cc9211f95ef46bdc8d151bf630e7ab2d62a21d24cd74c09cc5b73f3aa", size = 37052538, upload-time = "2025-10-29T18:11:46.233Z" }, + { url = "https://files.pythonhosted.org/packages/47/84/fd7ffac1551cd9f8d44fe41358f738be670fc4c24dfd514fab503f2cf3e7/z3_solver-4.15.4.0-py3-none-macosx_13_0_x86_64.whl", hash = "sha256:00bd10c5a6a5f6112d3a9a810d0799227e52f76caa860dafa5e00966bb47eb13", size = 39807925, upload-time = "2025-10-29T18:11:49.81Z" }, { url = "https://files.pythonhosted.org/packages/21/c9/bb51a96af0091324c81b803f16c49f719f9f6ea0b0bb52200f5c97ec4892/z3_solver-4.15.4.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e103a6f203f505b8b8b8e5c931cc407c95b61556512d4921c1ddc0b3f41b08e", size = 29268352, upload-time = "2025-10-29T18:11:53.032Z" }, + { url = "https://files.pythonhosted.org/packages/bf/2e/0b49f7e4e53817cfb09a0f6585012b782dfe0b666e8abefcb4fac0570606/z3_solver-4.15.4.0-py3-none-manylinux_2_34_aarch64.whl", hash = "sha256:62c7e9cbdd711932301f29919ad9158de9b2f58b4d281dd259bbcd0a2f408ba1", size = 27226534, upload-time = "2025-10-29T18:11:55.59Z" }, + { url = 
"https://files.pythonhosted.org/packages/26/91/33de49538444d4aafbe47415c450c2f9abab1733e1226f276b496672f46c/z3_solver-4.15.4.0-py3-none-win32.whl", hash = "sha256:be3bc916545c96ffbf89e00d07104ff14f78336e55db069177a1bfbcc01b269d", size = 13191672, upload-time = "2025-10-29T18:11:58.424Z" }, + { url = "https://files.pythonhosted.org/packages/03/d6/a0b135e4419df475177ae78fc93c422430b0fd8875649486f9a5989772e6/z3_solver-4.15.4.0-py3-none-win_amd64.whl", hash = "sha256:00e35b02632ed085ea8199fb230f6015e6fc40554a6680c097bd5f060e827431", size = 16259597, upload-time = "2025-10-29T18:12:01.14Z" }, ] [[package]] diff --git a/vllm_runtime/pyproject.toml b/vllm_runtime/pyproject.toml index 7d8bed9e5..673f66585 100644 --- a/vllm_runtime/pyproject.toml +++ b/vllm_runtime/pyproject.toml @@ -6,7 +6,7 @@ requires-python = ">=3.12,<3.13" dependencies = [ "nvidia-nccl-cu12==2.28.9 ; sys_platform == 'linux'", "transformers==5.6.2", - "vllm @ https://wheels.vllm.ai/ecd0b60aad2f4e28dd00ababfc1402690d88cbed/vllm-0.20.2rc1.dev168%2Bgecd0b60aa.cu129-cp38-abi3-manylinux_2_34_x86_64.whl ; sys_platform == 'linux'", + "vllm @ https://github.com/vllm-project/vllm/releases/download/v0.23.0/vllm-0.23.0%2Bcu129-cp38-abi3-manylinux_2_28_x86_64.whl ; sys_platform == 'linux'", ] [project.scripts] @@ -31,7 +31,7 @@ allow-direct-references = true [tool.uv] required-version = ">=0.6.15" override-dependencies = [ - "flashinfer-python==0.6.8.post1", + "flashinfer-python==0.6.12", "numpy<2", "nvidia-nccl-cu12==2.28.9 ; sys_platform == 'linux'", "torch @ https://download.pytorch.org/whl/test/cu128/torch-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", diff --git a/vllm_runtime/src/art_vllm_runtime/gemma4_moe_lora_patch.py b/vllm_runtime/src/art_vllm_runtime/gemma4_moe_lora_patch.py new file mode 100644 index 000000000..9da3be581 --- /dev/null +++ b/vllm_runtime/src/art_vllm_runtime/gemma4_moe_lora_patch.py @@ -0,0 +1,45 @@ +"""Gemma4 MoE LoRA compatibility for ART's vLLM runtime.""" + +from typing 
def patch_gemma4_moe_lora_support() -> None:
    """Attach the FusedMoE metadata that vLLM's native LoRA path looks up.

    Both Gemma4 model classes are flagged with ``is_3d_moe_weight = True``.
    A ``get_expert_mapping`` method is added to each class only when the
    class does not already provide one, so an upstream implementation wins.
    """
    from vllm.model_executor.layers.fused_moe import (
        fused_moe_make_expert_params_mapping,
    )
    from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
    from vllm.model_executor.models.gemma4_mm import Gemma4ForConditionalGeneration

    # Remove this shim when upstream vLLM Gemma4 MoE defines these natively.
    for model_cls in (Gemma4ForCausalLM, Gemma4ForConditionalGeneration):
        model_cls.is_3d_moe_weight = True

    causal_needs_mapping = not hasattr(Gemma4ForCausalLM, "get_expert_mapping")
    if causal_needs_mapping:

        def get_causal_expert_mapping(
            self: Any,
        ) -> list[tuple[str, str, int, str]]:
            # A missing or falsy ``num_experts`` on the config is passed as 0.
            configured_experts = getattr(self.config, "num_experts", 0) or 0
            return fused_moe_make_expert_params_mapping(
                self.model,
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=int(configured_experts),
                num_redundant_experts=0,
            )

        get_causal_expert_mapping.__art_patched__ = True  # type: ignore[attr-defined]
        Gemma4ForCausalLM.get_expert_mapping = get_causal_expert_mapping  # type: ignore[attr-defined,method-assign]

    # Checked after the causal class has been handled, as in the original order.
    if not hasattr(Gemma4ForConditionalGeneration, "get_expert_mapping"):

        def get_conditional_expert_mapping(
            self: Any,
        ) -> list[tuple[str, str, int, str]]:
            # The multimodal wrapper delegates to its inner language model.
            language_model = self.language_model
            return language_model.get_expert_mapping()

        get_conditional_expert_mapping.__art_patched__ = True  # type: ignore[attr-defined]
        Gemma4ForConditionalGeneration.get_expert_mapping = (  # type: ignore[attr-defined,method-assign]
            get_conditional_expert_mapping
        )
"lora_delta" +_LORA_A_SUFFIX = ".lora_A.weight" +_LORA_B_SUFFIX = ".lora_B.weight" +_GATE_UP_A_SUFFIX = ".base_layer.lora_A.weight" +_GATE_UP_B_SUFFIX = ".base_layer.lora_B.weight" +_PEFT_PREFIX = "base_model.model." + + +def _lora_scaling(adapter_config: dict[str, Any]) -> float: + rank = int(adapter_config["r"]) + alpha = float(adapter_config["lora_alpha"]) + return alpha / math.sqrt(rank) if adapter_config.get("use_rslora") else alpha / rank + + +def _checkpoint_base(base: str) -> str: + if base.startswith(_PEFT_PREFIX): + base = base.removeprefix(_PEFT_PREFIX) + return base.removesuffix(".base_layer") + + +def _lora_delta( + *, + a_key: str, + b_key: str, + lora_tensors: dict[str, torch.Tensor], + previous_lora_tensors: dict[str, torch.Tensor] | None, + scaling: float, +) -> torch.Tensor: + delta = lora_tensors[b_key].float().matmul(lora_tensors[a_key].float()) + delta.mul_(scaling) + if previous_lora_tensors is None: + return delta + previous_delta = ( + previous_lora_tensors[b_key] + .float() + .matmul(previous_lora_tensors[a_key].float()) + ) + return delta.sub_(previous_delta.mul_(scaling)) + + +def _unpack_expert_lora_b(tensor: torch.Tensor, *, rank: int) -> torch.Tensor: + num_experts = tensor.shape[1] // rank + return tensor.reshape(tensor.shape[0], rank, num_experts).permute(2, 0, 1) + + +def _iter_lora_checkpoint_deltas( + lora_tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], + previous_lora_tensors: dict[str, torch.Tensor] | None, +) -> Iterable[tuple[str, torch.Tensor]]: + rank = int(adapter_config["r"]) + scaling = _lora_scaling(adapter_config) + consumed: set[str] = set() + for a_key in sorted(lora_tensors): + if a_key.endswith(_GATE_UP_A_SUFFIX): + prefix = a_key.removesuffix(_GATE_UP_A_SUFFIX) + b_key = prefix + _GATE_UP_B_SUFFIX + consumed.update((a_key, b_key)) + a_tensor = lora_tensors[a_key] + b_tensor = _unpack_expert_lora_b(lora_tensors[b_key], rank=rank) + previous_b = ( + 
_unpack_expert_lora_b(previous_lora_tensors[b_key], rank=rank) + if previous_lora_tensors is not None + else None + ) + checkpoint_prefix = _checkpoint_base(prefix) + for expert_id, b_expert in enumerate(b_tensor): + expert_a = a_tensor[expert_id * rank : (expert_id + 1) * rank] + delta = b_expert.float().matmul(expert_a.float()).mul_(scaling) + if previous_b is not None: + previous_a = previous_lora_tensors[a_key][ + expert_id * rank : (expert_id + 1) * rank + ] + delta.sub_( + previous_b[expert_id] + .float() + .matmul(previous_a.float()) + .mul_(scaling) + ) + gate_delta, up_delta = delta.chunk(2, dim=0) + yield f"{checkpoint_prefix}.{expert_id}.gate_proj.weight", gate_delta + yield f"{checkpoint_prefix}.{expert_id}.up_proj.weight", up_delta + continue + if not a_key.endswith(_LORA_A_SUFFIX): + continue + prefix = a_key.removesuffix(_LORA_A_SUFFIX) + b_key = prefix + _LORA_B_SUFFIX + consumed.update((a_key, b_key)) + if prefix.endswith(".experts"): + a_tensor = lora_tensors[a_key] + b_tensor = _unpack_expert_lora_b(lora_tensors[b_key], rank=rank) + previous_b = ( + _unpack_expert_lora_b(previous_lora_tensors[b_key], rank=rank) + if previous_lora_tensors is not None + else None + ) + checkpoint_prefix = _checkpoint_base(prefix) + for expert_id, b_expert in enumerate(b_tensor): + expert_a = a_tensor[expert_id * rank : (expert_id + 1) * rank] + delta = b_expert.float().matmul(expert_a.float()).mul_(scaling) + if previous_b is not None: + previous_a = previous_lora_tensors[a_key][ + expert_id * rank : (expert_id + 1) * rank + ] + delta.sub_( + previous_b[expert_id] + .float() + .matmul(previous_a.float()) + .mul_(scaling) + ) + yield f"{checkpoint_prefix}.{expert_id}.down_proj.weight", delta + continue + yield ( + f"{_checkpoint_base(prefix)}.weight", + _lora_delta( + a_key=a_key, + b_key=b_key, + lora_tensors=lora_tensors, + previous_lora_tensors=previous_lora_tensors, + scaling=scaling, + ), + ) + unexpected = sorted(set(lora_tensors) - consumed) + if unexpected: 
+ raise RuntimeError(f"Unexpected LoRA tensor keys: {unexpected}") + + +def _default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + if param.numel() == 1 and loaded_weight.numel() == 1: + param.data.copy_(loaded_weight.view(param.shape)) + return + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) into parameter " + f"({param.size()})" + ) + param.data.copy_(loaded_weight) + + +def _additive_weight_loader(param: torch.Tensor, original_loader: Any) -> Any: + def load_delta( + loader_param: torch.Tensor, + loaded_weight: torch.Tensor, + *args: Any, + **kwargs: Any, + ) -> Any: + real_data = loader_param.data + scratch = torch.zeros_like(real_data) + loader_param.data = scratch + try: + result = original_loader(loader_param, loaded_weight, *args, **kwargs) + finally: + loader_param.data = real_data + if result is not False: + real_data.add_(scratch) + return result + + return load_delta + + +@contextmanager +def _additive_weight_loaders(model: Any) -> Any: + originals: list[tuple[torch.Tensor, bool, Any]] = [] + for param in model.parameters(): + has_loader = hasattr(param, "weight_loader") + original_loader = getattr(param, "weight_loader", _default_weight_loader) + originals.append((param, has_loader, original_loader)) + param.weight_loader = _additive_weight_loader(param, original_loader) # type: ignore[attr-defined] + try: + yield + finally: + for param, has_loader, original_loader in originals: + if has_loader: + param.weight_loader = original_loader # type: ignore[attr-defined] + else: + delattr(param, "weight_loader") + + +def apply_lora_delta_update( + *, + model: Any, + lora_tensors: dict[str, torch.Tensor], + adapter_config: dict[str, Any], + previous_lora_tensors: dict[str, torch.Tensor] | None, +) -> dict[str, torch.Tensor]: + if previous_lora_tensors is not None and set(lora_tensors) != set( + previous_lora_tensors + ): + raise RuntimeError( + "LoRA update key set changed: " + 
def apply_vllm_runtime_patches() -> None:
    """Install every ART runtime patch for vLLM, in a fixed order."""
    from art_vllm_runtime.gemma4_moe_lora_patch import (
        patch_gemma4_moe_lora_support,
    )

    # Config/model compatibility shims.
    patch_transformers_v5_compat()
    patch_gemma4_moe_lora_support()
    # Server-side patches.
    subclass_chat_completion_request()
    patch_listen_for_disconnect()
    patch_tool_parser_manager()
    # Worker / weight-transfer patches.
    patch_nccl_unique_id_bootstrap()
    patch_art_lora_delta_weight_update()
    patch_gemma4_checkpoint_weight_update_reload()
    patch_routed_experts_prefix_cache_sidecar()


def patch_transformers_v5_compat() -> None:
    """Run the three Transformers v5 compatibility shims."""
    _patch_rope_validation_ignore_keys()
    _patch_qwen3_vl_moe_tie_word_embeddings()
    _patch_gemma4_moe_experts_per_tok_alias()


def _patch_gemma4_moe_experts_per_tok_alias() -> None:
    """Expose ``top_k_experts`` as a read-only ``num_experts_per_tok`` property.

    Does nothing when ``Gemma4TextConfig`` already has that attribute.
    """
    from transformers import Gemma4TextConfig

    already_defined = hasattr(Gemma4TextConfig, "num_experts_per_tok")
    if already_defined:
        return

    def num_experts_per_tok(self: Any) -> Any:
        # vLLM's routed-expert sidecar uses the Mistral MoE field name, while
        # HF Gemma4 stores the same router top-k value as top_k_experts.
        return self.top_k_experts

    alias = property(num_experts_per_tok)
    Gemma4TextConfig.num_experts_per_tok = alias  # type: ignore[attr-defined]


def _is_gemma4_conditional_worker(worker: Any) -> bool:
    """Return True when the worker's HF config lists exactly the Gemma4 multimodal architecture."""
    architectures = worker.model_config.hf_config.architectures
    return architectures == ["Gemma4ForConditionalGeneration"]


def patch_gemma4_checkpoint_weight_update_reload() -> None:
    """Bypass the upstream start/finish hooks for Gemma4 checkpoint-format updates.

    For Gemma4 multimodal workers receiving checkpoint-format weights, the
    replacement hooks only toggle the worker's bookkeeping flags. Every other
    case is forwarded to the upstream implementation. Idempotent.
    """
    from vllm.v1.worker.gpu_worker import Worker

    upstream_start = Worker.start_weight_update
    if getattr(upstream_start, "__art_patched__", False):
        return
    upstream_finish = Worker.finish_weight_update

    def start_weight_update(
        self: Any,
        is_checkpoint_format: bool = True,
    ) -> None:
        if is_checkpoint_format and _is_gemma4_conditional_worker(self):
            self._check_weight_transfer_engine()
            if self._weight_update_active:
                raise RuntimeError(
                    "start_weight_update called while a weight update is "
                    "already active. Call finish_weight_update first."
                )
            # vLLM's layerwise checkpoint reload corrupts Gemma4 after reloading
            # the original checkpoint. Direct model.load_weights keeps the update
            # path identical to initial checkpoint loading while preserving the
            # streaming NCCL transfer used by ART merged serving.
            self._is_checkpoint_format = True
            self._weight_update_active = True
            return None
        return upstream_start(
            self,
            is_checkpoint_format=is_checkpoint_format,
        )

    def finish_weight_update(self: Any) -> None:
        if not _is_gemma4_conditional_worker(self):
            return upstream_finish(self)
        self._check_weight_transfer_engine()
        if not self._weight_update_active:
            raise RuntimeError(
                "start_weight_update must be called before finish_weight_update."
            )
        if self._is_checkpoint_format:
            self._weight_update_active = False
            self._is_checkpoint_format = True
            return None
        return upstream_finish(self)

    # Tag each wrapper and keep a handle to the function it replaced.
    for wrapper, upstream in (
        (start_weight_update, upstream_start),
        (finish_weight_update, upstream_finish),
    ):
        wrapper.__art_patched__ = True  # type: ignore[attr-defined]
        wrapper.__art_original__ = upstream  # type: ignore[attr-defined]
    Worker.start_weight_update = start_weight_update  # type: ignore[method-assign]
    Worker.finish_weight_update = finish_weight_update  # type: ignore[method-assign]


def patch_art_lora_delta_weight_update() -> None:
    """Teach ``Worker.update_weights`` to apply ART LoRA delta updates.

    Updates tagged with ``art_weight_update_kind == ART_LORA_DELTA_UPDATE_KIND``
    are received as LoRA tensors and added onto the model weights. All other
    updates go to the upstream implementation unchanged. Idempotent.
    """
    import torch
    from vllm.v1.worker.gpu_worker import Worker

    from art_vllm_runtime.lora_delta import (
        ART_LORA_DELTA_UPDATE_KIND,
        apply_lora_delta_update,
    )

    upstream_update_weights = Worker.update_weights
    if getattr(upstream_update_weights, "__art_lora_delta_patched__", False):
        return

    art_only_keys = ("art_weight_update_kind", "art_lora_config")

    def update_weights(self: Any, update_info: dict) -> None:
        requested_kind = update_info.get("art_weight_update_kind")
        if requested_kind != ART_LORA_DELTA_UPDATE_KIND:
            # NOTE(review): this path leaves _art_previous_lora_tensors as-is;
            # confirm a full reload cannot be interleaved with delta updates.
            return upstream_update_weights(self, update_info)

        self._check_weight_transfer_engine()
        engine = self.weight_transfer_engine
        assert engine is not None
        if not self._weight_update_active:
            raise RuntimeError(
                "start_weight_update must be called before update_weights."
            )

        adapter_config = update_info["art_lora_config"]
        # The transfer engine only sees the keys it knows about.
        transfer_update_info = {
            key: value
            for key, value in update_info.items()
            if key not in art_only_keys
        }
        typed_update_info = engine.parse_update_info(transfer_update_info)
        lora_tensors: dict[str, torch.Tensor] = {}

        def collect_lora_tensors(weights: list[tuple[str, torch.Tensor]]) -> None:
            for name, tensor in weights:
                if name in lora_tensors:
                    raise RuntimeError(f"Duplicate LoRA tensor in update: {name}")
                lora_tensors[name] = tensor.detach().contiguous().clone()

        # NOTE(review): the delta application is assumed to sit inside the
        # device context, as the statement order suggests; confirm.
        with torch.device(self.device):
            engine.receive_weights(
                typed_update_info,
                load_weights=collect_lora_tensors,
            )
            previous = getattr(self, "_art_previous_lora_tensors", None)
            self._art_previous_lora_tensors = apply_lora_delta_update(
                model=self.model_runner.model,
                lora_tensors=lora_tensors,
                adapter_config=adapter_config,
                previous_lora_tensors=previous,
            )

        torch.accelerator.synchronize()

    update_weights.__art_lora_delta_patched__ = True  # type: ignore[attr-defined]
    update_weights.__art_original__ = upstream_update_weights  # type: ignore[attr-defined]
    Worker.update_weights = update_weights  # type: ignore[method-assign]
+ setattr(routed_experts_capturer, "_art_prefix_route_sidecar_patched", True) + return + host_cls = routed_experts_capturer._RoutedExpertsHostCache capturer_cls = routed_experts_capturer._RoutedExpertsCapturerReal diff --git a/vllm_runtime/uv.lock b/vllm_runtime/uv.lock index 1956cd581..f647be079 100644 --- a/vllm_runtime/uv.lock +++ b/vllm_runtime/uv.lock @@ -4,7 +4,7 @@ requires-python = "==3.12.*" [manifest] overrides = [ - { name = "flashinfer-python", specifier = "==0.6.8.post1" }, + { name = "flashinfer-python", specifier = "==0.6.12" }, { name = "numpy", specifier = "<2" }, { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'", specifier = "==2.28.9" }, { name = "torch", url = "https://download.pytorch.org/whl/test/cu128/torch-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" }, @@ -143,7 +143,7 @@ dependencies = [ requires-dist = [ { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'", specifier = "==2.28.9" }, { name = "transformers", specifier = "==5.6.2" }, - { name = "vllm", marker = "sys_platform == 'linux'", url = "https://wheels.vllm.ai/ecd0b60aad2f4e28dd00ababfc1402690d88cbed/vllm-0.20.2rc1.dev168%2Bgecd0b60aa.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, + { name = "vllm", marker = "sys_platform == 'linux'", url = "https://github.com/vllm-project/vllm/releases/download/v0.23.0/vllm-0.23.0%2Bcu129-cp38-abi3-manylinux_2_28_x86_64.whl" }, ] [[package]] @@ -282,7 +282,7 @@ wheels = [ [[package]] name = "compressed-tensors" -version = "0.15.0.1" +version = "0.17.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "loguru" }, @@ -290,9 +290,9 @@ dependencies = [ { name = "torch" }, { name = "transformers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/41/1b/c3c4a98ec5f2727656336f07a0c35862195c310d8eb0b2fa5b4be6848680/compressed_tensors-0.15.0.1.tar.gz", hash = "sha256:a8e93054e8a5ec49c980b09ed36c4c1249b4a8ee167920a8e461c4da26e78d99", size = 229412, upload-time = "2026-04-10T14:23:54.708Z" 
} +sdist = { url = "https://files.pythonhosted.org/packages/2c/9e/d7f18bd9a0354088abc11a0c1f2c7698f7c49e5a709faedf6a46e388f693/compressed_tensors-0.17.0.tar.gz", hash = "sha256:15c20d06bdbcf35b51fc99fd125e7b9be1e1855567c33b7a46dfac26ad6fb126", size = 257091, upload-time = "2026-06-03T16:49:17.208Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/52/93833dc1610e017ac5b7dcd59b8304d8ef67d1114c2d124e728a2cbbea12/compressed_tensors-0.15.0.1-py3-none-any.whl", hash = "sha256:e1b1f322e82e475715e242bad46925a304ea8e5c98b5055a15b8eb22fb6bfea9", size = 194260, upload-time = "2026-04-10T14:23:53.098Z" }, + { url = "https://files.pythonhosted.org/packages/35/63/6edf0415b072fff0bf8b546074dea3f0f9b148e49b601ac98bdc60a76c68/compressed_tensors-0.17.0-py3-none-any.whl", hash = "sha256:4a1b89b508f7efb8ffb4eee8a6e69e0452d9b080cae130146025c64fbe9fa9aa", size = 211714, upload-time = "2026-06-03T16:49:15.672Z" }, ] [[package]] @@ -371,6 +371,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/76/84cb68be463c827bf79da9fa0aa5140838de6455ef6f438bbe0ffa75d378/cuda_tile-1.3.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:e4865acbff1172aaee304bf9c550586088d8b4545a384423597a590899386709", size = 247301, upload-time = "2026-04-20T15:51:04.042Z" }, ] +[package.optional-dependencies] +tileiras = [ + { name = "nvidia-cuda-nvcc" }, + { name = "nvidia-cuda-tileiras" }, + { name = "nvidia-nvvm" }, +] + [[package]] name = "cuda-toolkit" version = "12.8.1" @@ -582,14 +589,15 @@ wheels = [ [[package]] name = "fastsafetensors" -version = "0.3.1" +version = "0.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d2/69/e34a1e86a02b255896c57263bf0dfbae45b4708fd609b937f783c2202e7b/fastsafetensors-0.3.1.tar.gz", hash = "sha256:b7eb039a564d77280d17e5d63b27e9963ba5158ad02d2a3c1772c62072a81a53", size = 55665, upload-time = "2026-05-06T08:48:59.125Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/c8/33/c97b2bcbe06e0f011eedee0f41d4060f6344901a53c2703acc3dd7429713/fastsafetensors-0.3.2.tar.gz", hash = "sha256:9e358fce238684613a5c3ebb7800c52c5b3270c0bb5e4ed2191ee8f3d0431de1", size = 70409, upload-time = "2026-05-22T05:39:34.787Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6f/50/909871d673bacd6dfc7fee5e59bcd4ec9fbd19775bafe567ad236a3adced/fastsafetensors-0.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ac76f33e47959b7c31658fbbda1805df7540819828a3ce6a94eb34b4db0b1fa7", size = 1854825, upload-time = "2026-05-06T08:48:54.452Z" }, + { url = "https://files.pythonhosted.org/packages/c9/bb/9f821eac9bddd41ea1c5cd9b6a597c002741f022ecf6f3ba5cfcc3e9c950/fastsafetensors-0.3.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:69f4d8cbd3b542e5ddf7fee8136cf35e1524f9c30e118f64a0e846dab7e8de6b", size = 1877989, upload-time = "2026-06-04T09:02:56.11Z" }, + { url = "https://files.pythonhosted.org/packages/e9/68/a31c1661adf4d1b5ec29470ff991bde9094e4f347b0e6d1af8ba6b560d32/fastsafetensors-0.3.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6a932d7166c9e17e48aca3e5503d326bc6fc73fce6dc985ae6bd2ccc0f308b14", size = 1907188, upload-time = "2026-05-22T05:39:30.242Z" }, ] [[package]] @@ -603,20 +611,20 @@ wheels = [ [[package]] name = "flashinfer-cubin" -version = "0.6.8.post1" +version = "0.6.12" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/11/b7/5e3b1a8c67031b421a8bd29c2bc29b900a550bb3392e8bda18bb15b5e476/flashinfer_cubin-0.6.8.post1-py3-none-any.whl", hash = "sha256:43636d4cd39e694a83d76a89f87fefcdf4cecb4c4f7dd22dac25ec368c1e901f", size = 295154113, upload-time = "2026-04-18T18:28:21.738Z" }, + { url = "https://files.pythonhosted.org/packages/7d/c6/63b1bb7b1a7ae612ecf53c0e568312c3d004f9f7558b0ab5edcf7900c360/flashinfer_cubin-0.6.12-py3-none-any.whl", hash = 
"sha256:01de132c493bb21d5df42ebe6890966cf83b40aa970dae06b2a3c0bed85f13ec", size = 447533460, upload-time = "2026-05-29T23:45:27.579Z" }, ] [[package]] name = "flashinfer-python" -version = "0.6.8.post1" +version = "0.6.12" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "apache-tvm-ffi" }, { name = "click" }, - { name = "cuda-tile" }, + { name = "cuda-tile", extra = ["tileiras"] }, { name = "einops" }, { name = "ninja" }, { name = "numpy" }, @@ -629,9 +637,9 @@ dependencies = [ { name = "torch" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/53/1e/2760fef9e74abc4480961048e5790b4c9e955872fb4d7d97900cfddced5a/flashinfer_python-0.6.8.post1.tar.gz", hash = "sha256:b18e4121baf9b93fa9a9f368ba9b981a0342895f50ab9dddc224aeb964ed346f", size = 6675885, upload-time = "2026-04-18T18:28:13.299Z" } +sdist = { url = "https://files.pythonhosted.org/packages/61/d0/114a64319f5a804def2f307d5ed8f95e6d94a2acdacac4ed5f57525cbf46/flashinfer_python-0.6.12.tar.gz", hash = "sha256:bed67f9c46d81dd22611dfef2787998fc412b2fe2648d9e7d336861dda912694", size = 9453326, upload-time = "2026-05-29T23:45:16.466Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/73/6d/1e8a8533913e33a50a486332ce0673f4fdb860f6eb9ed450327c5c1762cb/flashinfer_python-0.6.8.post1-py3-none-any.whl", hash = "sha256:818f9b8cc2fe66c42a1f6264be4841ac8821ada703685a02cfccb2b5124a710b", size = 9385316, upload-time = "2026-04-18T18:28:10.285Z" }, + { url = "https://files.pythonhosted.org/packages/85/26/3ca33edbf64906603633cb91904798e427c0ac1c55a13707f8081708f3ae/flashinfer_python-0.6.12-py3-none-any.whl", hash = "sha256:0c7a01e586b4796810d974cbf13a9c0eb2ade6a94d12e3220cf7782a1c09b8d3", size = 13985243, upload-time = "2026-05-29T23:45:13.477Z" }, ] [[package]] @@ -801,6 +809,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/57/d4/e33bf0b362810a9b96c5923e38908950d58ecb512db42e3730320c7f4a3a/huggingface_hub-1.9.2-py3-none-any.whl", hash = 
"sha256:e1e62ce237d4fbeca9f970aeb15176fbd503e04c25577bfd22f44aa7aa2b5243", size = 637349, upload-time = "2026-04-08T08:43:09.114Z" }, ] +[[package]] +name = "humming-kernels" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-bindings" }, + { name = "jinja2" }, + { name = "numpy" }, + { name = "nvidia-ml-py" }, + { name = "pyelftools" }, + { name = "safetensors" }, + { name = "tabulate" }, + { name = "torch" }, + { name = "tqdm" }, + { name = "triton" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/f6/05e95b66cca48def9db0d6c40374fe285c7d9c913fe126030bcfb7cb3088/humming_kernels-0.1.4.tar.gz", hash = "sha256:fdaf4f23cc6b03bb1be3fd24aa11dc7798881e5448826e2404b4f12d8096f0d0", size = 117555, upload-time = "2026-06-04T03:24:03.504Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/16/d9318061a560305034e14cb7bf6483ffc8735eff6b30f260907dbbd4e85d/humming_kernels-0.1.4-py3-none-any.whl", hash = "sha256:c85094cd7cf8cdd959c5e2f7f239a7d72a7640ec1f948787434bc06e24e9ed00", size = 161312, upload-time = "2026-06-04T03:24:01.897Z" }, +] + +[package.optional-dependencies] +cu12 = [ + { name = "nvidia-cuda-cccl-cu12" }, + { name = "nvidia-cuda-nvcc-cu12" }, + { name = "nvidia-cuda-nvrtc-cu12" }, + { name = "nvidia-cuda-runtime-cu12" }, +] + [[package]] name = "idna" version = "3.11" @@ -922,12 +959,12 @@ wheels = [ [[package]] name = "llguidance" -version = "1.3.0" +version = "1.7.6" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/95/48/3f7a9d3ff1b36bba92b5107a3a21286821227afe9ea464736133994d61fb/llguidance-1.3.0.tar.gz", hash = "sha256:861249afd51dc325646834462ea827e57a5c2b2042e108e6aae7059fdad9104d", size = 1070460, upload-time = "2025-10-20T19:58:44.164Z" } +sdist = { url = "https://files.pythonhosted.org/packages/da/91/6bc8bb503dc259e46d253b5424385a54fe06c38a4c7a12befe69a3c2455a/llguidance-1.7.6.tar.gz", hash = 
"sha256:db7febbe412ed2015501904646750071d7e00e6df7f85c4b956ad4f206fd2df7", size = 1156574, upload-time = "2026-06-03T20:13:25.316Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/11/44389d3d1526d7a5c38ffd587a5ebc61d7bee443ac1dea95f2089ad58f5f/llguidance-1.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f6caca5d78db7f76e1fbb0fff8607b861c32d47fa3d5dee2fc49de27ee269df", size = 2835242, upload-time = "2025-10-20T19:58:34.518Z" }, - { url = "https://files.pythonhosted.org/packages/83/a8/1ff2bedb8f9acb46a2d2d603415d272bb622c142ea86f5b95445cc6e366c/llguidance-1.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc17e9dd602c3879bf91664a64bf72f54c74dbfbeb24ccfab6a5fe435b12f7aa", size = 3033133, upload-time = "2025-10-20T19:58:38.721Z" }, + { url = "https://files.pythonhosted.org/packages/51/b9/dc76d7716e04dc7b3427cae52eaa32bd20771382d4d1dd9f4538a9dd2086/llguidance-1.7.6-cp39-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:e70fa25ed550c2b50c2fd70baa9e2808b4ecb859d01e453bd5459aff62ba38c3", size = 2899993, upload-time = "2026-06-03T20:13:13.563Z" }, + { url = "https://files.pythonhosted.org/packages/1a/64/d74336f22242ef94356a456057d4ff1be7c1bc9c7dbc867171c6982a5512/llguidance-1.7.6-cp39-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:ceec951d29a74309984e3be0fe7f5f56c1362434cd937abd517b259a60908b1e", size = 3074809, upload-time = "2026-06-03T20:13:15.498Z" }, ] [[package]] @@ -1028,7 +1065,7 @@ wheels = [ [[package]] name = "mistral-common" -version = "1.11.2" +version = "1.11.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jsonschema" }, @@ -1040,9 +1077,9 @@ dependencies = [ { name = "tiktoken" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c2/eb/12167a1bea9714582e5b4f539f9c019323363e314a499c72855ff0e5ad43/mistral_common-1.11.2.tar.gz", hash = "sha256:79f68fc2d1190f28637f40e053f919c8c2697e00b2aa679ddee562a95183f4ad", 
size = 6357845, upload-time = "2026-05-04T19:47:40.413Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/03/3c5d4c9430da406f8444f9a7b058a6aa89c525fb068a57fe2ab8b04a6d08/mistral_common-1.11.3.tar.gz", hash = "sha256:6437e128fc8a307318440839ca14ddf2e8060056b062233ec0db10352651374c", size = 6360629, upload-time = "2026-06-04T09:01:11.131Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/47/f0/6a5d604b972e442b9d36c117d01788feddad099e4965699e3516ee6fefc3/mistral_common-1.11.2-py3-none-any.whl", hash = "sha256:ebb42062cd705a0aa2bc69b4cde2b83d446ae58150b7e29322c90cb08fcfca6c", size = 6531968, upload-time = "2026-05-04T19:47:37.718Z" }, + { url = "https://files.pythonhosted.org/packages/7b/76/dbfdf9c59e2a4b0116587626a3768c2a3b2ba1758b5756743918c2337fdc/mistral_common-1.11.3-py3-none-any.whl", hash = "sha256:dbfcef9d0c892727ee08a080f0c1039baed5430b291f5425ffd88892bf09e52c", size = 6533154, upload-time = "2026-06-04T09:01:14.186Z" }, ] [package.optional-dependencies] @@ -1193,6 +1230,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, ] +[[package]] +name = "nvidia-cuda-cccl-cu12" +version = "12.9.27" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/7e/82e49956b046bdc506c789235c587d9b3ef58b8bc1782258c1e247229647/nvidia_cuda_cccl_cu12-12.9.27-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7898b38aa68beaa234d48f0868273702342a196d6e2e9d0ef058dca2390ebea", size = 3152245, upload-time = "2025-05-01T19:32:04.802Z" }, + { url = 
"https://files.pythonhosted.org/packages/18/2a/d4cd8506d2044e082f8cd921be57392e6a9b5ccd3ffdf050362430a3d5d5/nvidia_cuda_cccl_cu12-12.9.27-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:37869e17ce2e1ecec6eddf1927cca0f8c34e64fd848d40453df559091e2d7117", size = 3152243, upload-time = "2025-05-01T19:32:13.955Z" }, +] + +[[package]] +name = "nvidia-cuda-crt" +version = "13.3.33" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/32/5ea57f8cd6ad5df2173d175ac5db4e06edde40028b1b1f6c539ea4c10290/nvidia_cuda_crt-13.3.33-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c8c257393f9c9146a85d3644f352be8154843d760031f756e673222c768a4930", size = 157348, upload-time = "2026-05-26T16:28:40.446Z" }, + { url = "https://files.pythonhosted.org/packages/8d/a7/998af901511d5efdc6e42fc597d32a69f34eecf86f1591a9d230ab3ab951/nvidia_cuda_crt-13.3.33-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:01ff37600c7b880a14cab4ade763b4c10c0ff92f25cc9dca30f0881ce52693c4", size = 157350, upload-time = "2026-05-26T16:29:22.315Z" }, +] + [[package]] name = "nvidia-cuda-cupti-cu12" version = "12.8.90" @@ -1202,6 +1257,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, ] +[[package]] +name = "nvidia-cuda-nvcc" +version = "13.2.78" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cuda-crt" }, + { name = "nvidia-cuda-runtime" }, + { name = "nvidia-nvvm" }, +] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/ec/df/faf551572ae1359290afa5cb05d2c4b7e6674b07b8283b20eab4dbad15f6/nvidia_cuda_nvcc-13.2.78-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:dfc76950c775cd00ce588f15192f08c9b858c0dcfa7da685acf39a3d0d8f588b", size = 38713559, upload-time = "2026-04-13T09:42:17.478Z" }, + { url = "https://files.pythonhosted.org/packages/65/0f/c7c7d538c61794130e759ad74710ab5aa8cab1f700ee1754381f8c665605/nvidia_cuda_nvcc-13.2.78-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c3bd144dd9b6b25e062589acb7bbd43d93d3120c72fad71da808f9817aba1239", size = 44040318, upload-time = "2026-04-13T09:42:50.457Z" }, +] + +[[package]] +name = "nvidia-cuda-nvcc-cu12" +version = "12.9.86" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/48/b54a06168a2190572a312bfe4ce443687773eb61367ced31e064953dd2f7/nvidia_cuda_nvcc_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:5d6a0d32fdc7ea39917c20065614ae93add6f577d840233237ff08e9a38f58f0", size = 40546229, upload-time = "2025-06-05T20:01:53.357Z" }, + { url = "https://files.pythonhosted.org/packages/d6/5c/8cc072436787104bbbcbde1f76ab4a0d89e68f7cebc758dd2ad7913a43d0/nvidia_cuda_nvcc_cu12-12.9.86-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:44e1eca4d08926193a558d2434b1bf83d57b4d5743e0c431c0c83d51da1df62b", size = 39411138, upload-time = "2025-06-05T20:01:43.182Z" }, +] + [[package]] name = "nvidia-cuda-nvrtc-cu12" version = "12.8.93" @@ -1211,6 +1289,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" }, ] +[[package]] +name = 
"nvidia-cuda-runtime" +version = "13.3.29" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/e5/c1a221c8e6fecd071b80ea44c20fc253ae24f56e15e3f77cfbc3fb76e724/nvidia_cuda_runtime-13.3.29-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:73291e19c9dd919c140c91bda2f80b0eca487da5ee30a086ef7bc4918ecb90ea", size = 2356574, upload-time = "2026-05-26T16:29:56.333Z" }, + { url = "https://files.pythonhosted.org/packages/97/be/5699b6e642b372f7d24c59c2f41383e2696825e20bab85f7399c7c6a56f7/nvidia_cuda_runtime-13.3.29-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e04420616e72f563167a7733272992d7e6df6dc5cb54b2f94f9f1520ea9e30c1", size = 2339786, upload-time = "2026-05-26T16:30:21.584Z" }, +] + [[package]] name = "nvidia-cuda-runtime-cu12" version = "12.8.90" @@ -1220,6 +1307,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, ] +[[package]] +name = "nvidia-cuda-tileiras" +version = "13.2.78" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cuda-nvcc" }, + { name = "nvidia-nvvm" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/04/eb26cc1d67c653f5dbe8c13fd6da9c1e844b097147051b5052ac5e6d4047/nvidia_cuda_tileiras-13.2.78-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:658299efca52a20496b425efb0b19cb1ea7d57406a18d3f5024d4df92d5b54c1", size = 36418791, upload-time = "2026-04-13T09:48:30.107Z" }, + { url = 
"https://files.pythonhosted.org/packages/7f/b8/c8a96862268943c7cf30a014fe2d8f70c651d30fbfa790d54c3e347b6fa1/nvidia_cuda_tileiras-13.2.78-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce7c140a518aa8dfe033e7176f593617ed2fece0e50331e2a14dafd236723fd", size = 36970479, upload-time = "2026-04-13T09:48:49.919Z" }, +] + [[package]] name = "nvidia-cudnn-cu12" version = "9.19.0.56" @@ -1234,11 +1334,11 @@ wheels = [ [[package]] name = "nvidia-cudnn-frontend" -version = "1.18.0" +version = "1.25.0" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/b4/604e230378680ee117849a4e1045baca092f93161a829291a84d5acce70c/nvidia_cudnn_frontend-1.18.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:310b417f2848a83d1437203fcaeea320a74fb7f28af20bf42bf5afc9c01f1c12", size = 2027408, upload-time = "2026-01-27T23:32:46.576Z" }, - { url = "https://files.pythonhosted.org/packages/c6/52/08f98262e77b1cbcc834cc1a5db494d0661ea1dbdea58c2e2d51a57fdaca/nvidia_cudnn_frontend-1.18.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6c023539ca6de99234cf5102c3ec0d6af817f5396fc93028a22ba5b834a35b8a", size = 2159245, upload-time = "2026-01-27T23:07:32.664Z" }, + { url = "https://files.pythonhosted.org/packages/28/0f/df39a194f2529093db737d43cc4cbf594c6a79712a09aa104b999e4d95d4/nvidia_cudnn_frontend-1.25.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:09e6e1bc48ce1235743f89d8ea699c52b3008fd6dae7f2ecadb744bebf272a2b", size = 3263306, upload-time = "2026-06-10T21:07:48.093Z" }, + { url = "https://files.pythonhosted.org/packages/03/65/3b45941d8a22128b971e910f2e9af6bf5ef453e92cc329c56b6eb53c53de/nvidia_cudnn_frontend-1.25.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9a94a72d736bd79eb35f451aaf26d9493778e02ecabccc92c05425508c9e7a83", size = 3414884, upload-time = "2026-06-10T21:08:08.603Z" }, ] 
[[package]] @@ -1308,18 +1408,18 @@ wheels = [ [[package]] name = "nvidia-cutlass-dsl" -version = "4.4.2" +version = "4.5.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-cutlass-dsl-libs-base" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/a9/03/678dab0383db1ddfc449da216220f40404189eb36eeed9d87a4fa4bdb0e6/nvidia_cutlass_dsl-4.4.2-py3-none-any.whl", hash = "sha256:7cfb9ef19062b055b9372c7a627004724e2755e4c8b16c3cc88807d64501a4ae", size = 10167, upload-time = "2026-03-16T02:18:59.043Z" }, + { url = "https://files.pythonhosted.org/packages/f0/15/575d7df4fe2f3406f1cfc68be72aeff2834f8a696daf1cd5bee8017e4507/nvidia_cutlass_dsl-4.5.2-py3-none-any.whl", hash = "sha256:68ed1b63ca74aae87955012da9dfd7fdaae471329d0028b229b841c7192ccf52", size = 10179, upload-time = "2026-05-25T03:38:56.364Z" }, ] [[package]] name = "nvidia-cutlass-dsl-libs-base" -version = "4.4.2" +version = "4.5.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cuda-python" }, @@ -1327,8 +1427,8 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/8e/7d/0df5e38d11e52cc72095a14d6448bc1c5d0d4b00b069a1189ca417fb225b/nvidia_cutlass_dsl_libs_base-4.4.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:2ec8812eeadcbb6fe20bda2e295ed9c00653f8253b78e33cf0ab65a47b829e73", size = 75473821, upload-time = "2026-03-16T02:27:08.371Z" }, - { url = "https://files.pythonhosted.org/packages/56/98/e264964741d9cc9816625d9600d17a5249fd5cbd8c2d166fb0d0c34dfe5a/nvidia_cutlass_dsl_libs_base-4.4.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:22e37b58f7a6f2f43bba533c4df8a088012122e0b4e9a632eca23937adeafb39", size = 74355593, upload-time = "2026-03-16T02:25:11.762Z" }, + { url = "https://files.pythonhosted.org/packages/b1/ef/e827e3c67d72adbf4e8f680bdf03b1b67723d9e1ae7c3d0a1751f39f69ce/nvidia_cutlass_dsl_libs_base-4.5.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = 
"sha256:d2a3c412287e356fbe48fe9f845d6d33cd35dea5e20d7e4f628c20957967cacd", size = 75643473, upload-time = "2026-05-25T03:49:15.857Z" }, + { url = "https://files.pythonhosted.org/packages/97/68/c1247ab848f26c4ab56e562eea0e3f31fc14c9aaf0d883afaa92d8f05592/nvidia_cutlass_dsl_libs_base-4.5.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:15ef6a59193667e663934ef4873f8ccad37455e9b7c3c419c3072113b8aedf61", size = 74513226, upload-time = "2026-05-25T03:51:32.496Z" }, ] [[package]] @@ -1376,6 +1476,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, ] +[[package]] +name = "nvidia-nvvm" +version = "13.2.78" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/1f/930d63ccc8adcdf27bfc051a24e3e4da2cf6ef987848d6d1d642e29d704b/nvidia_nvvm-13.2.78-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:f5aa433631109bbdec81802c5b5f319bf10bc891fe2f212e4e445845211d6f77", size = 64279462, upload-time = "2026-04-13T10:02:25.719Z" }, + { url = "https://files.pythonhosted.org/packages/8b/fd/db44b7a662a6af75a9a0683ca4580c855a3f5fcfdf1261b0ddb9fce0ee26/nvidia_nvvm-13.2.78-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:88075f87a361a1dce95c799cabc028f7093af616a5702dcfb74eba4045dbbd5f", size = 61886055, upload-time = "2026-04-13T10:02:00.345Z" }, +] + [[package]] name = "openai" version = "2.24.0" @@ -1786,6 +1895,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 
58929, upload-time = "2026-02-19T13:45:06.034Z" }, ] +[[package]] +name = "pyelftools" +version = "0.33" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/11/767522582afab1b884d277de0e6e011640cb9d7292a38694b4b1a1df1ae8/pyelftools-0.33.tar.gz", hash = "sha256:660d82dcbeb8e83d1702bd97f223f761625da06111c0cc988eac6b8ab0c1b61f", size = 15068655, upload-time = "2026-05-29T12:56:22.553Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/2a/f9697576603dae937727827505a6126a066affb227034e77e6f9068910da/pyelftools-0.33-py3-none-any.whl", hash = "sha256:f215ad5f47d3f1373a21496a6c9e0707c622840d0622f23ff7ce08678b020036", size = 201178, upload-time = "2026-05-29T12:56:20.587Z" }, +] + [[package]] name = "pygments" version = "2.20.0" @@ -2218,6 +2336,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/72/f4/0de46cfa12cdcbcd464cc59fde36912af405696f687e53a091fb432f694c/tokenizers-0.22.2-cp39-abi3-win_arm64.whl", hash = "sha256:9ce725d22864a1e965217204946f830c37876eee3b2ba6fc6255e8e903d5fcbc", size = 2612133, upload-time = "2026-01-05T10:45:17.232Z" }, ] +[[package]] +name = "tokenspeed-mla" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apache-tvm-ffi" }, + { name = "nvidia-cutlass-dsl" }, + { name = "tokenspeed-triton" }, + { name = "torch" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/20/4110d624d81d63f0bee2f19dba7ea0e1d8a31ea50147e6c1db82223c88a4/tokenspeed_mla-0.1.2-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:592590f36d85e624ecdc5e357ff35e29e761e6d879900dce8b67a6785c8ce75c", size = 743769, upload-time = "2026-05-13T03:30:54.486Z" }, + { url = "https://files.pythonhosted.org/packages/84/01/4bf8b74ead3e8e7c1c809435396254c067a33fde48acc20f602aae622d97/tokenspeed_mla-0.1.2-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:c9466a351fe039792e56cf49f3e79744c1dc28c7af10306a02e62b8e92fa5985", size = 748681, 
upload-time = "2026-05-13T03:30:56.718Z" }, +] + +[[package]] +name = "tokenspeed-triton" +version = "3.7.10.post20260531" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/58/fdb5fb70d99c1f18f01c2198420fa2a0f7e5301bd7dd5b5f34b22a3cb87b/tokenspeed_triton-3.7.10.post20260531-cp312-abi3-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:16cd0a3fc1cffeb458a7e03e8688714f49fdf0b5a108bfca999f46597c3faabb", size = 81636010, upload-time = "2026-05-31T01:29:17.699Z" }, + { url = "https://files.pythonhosted.org/packages/d7/49/7bae94729bfd7a3f331795251302f0b0c8e54a7ec25b3af5d5bfe133367c/tokenspeed_triton-3.7.10.post20260531-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b90ac41e7f15933797545ff1a9e803a9d8beb4ca9ba70f6d41a9e0fc26484f5c", size = 85888791, upload-time = "2026-05-31T01:29:25.584Z" }, +] + [[package]] name = "torch" version = "2.11.0+cu128" @@ -2432,8 +2574,8 @@ wheels = [ [[package]] name = "vllm" -version = "0.20.2rc1.dev168+gecd0b60aa.cu129" -source = { url = "https://wheels.vllm.ai/ecd0b60aad2f4e28dd00ababfc1402690d88cbed/vllm-0.20.2rc1.dev168%2Bgecd0b60aa.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" } +version = "0.23.0+cu129" +source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.23.0/vllm-0.23.0%2Bcu129-cp38-abi3-manylinux_2_28_x86_64.whl" } dependencies = [ { name = "aiohttp" }, { name = "anthropic" }, @@ -2452,6 +2594,7 @@ dependencies = [ { name = "flashinfer-cubin" }, { name = "flashinfer-python" }, { name = "gguf" }, + { name = "humming-kernels", extra = ["cu12"] }, { name = "ijson" }, { name = "lark" }, { name = "llguidance", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 'x86_64'" }, @@ -2488,6 +2631,7 @@ dependencies = [ { name = "quack-kernels" }, { name = "regex" }, { name = "requests" }, + { name = "safetensors" }, { name = "sentencepiece" 
}, { name = "setproctitle" }, { name = "setuptools" }, @@ -2495,6 +2639,7 @@ dependencies = [ { name = "tiktoken" }, { name = "tilelang" }, { name = "tokenizers" }, + { name = "tokenspeed-mla" }, { name = "torch" }, { name = "torchaudio" }, { name = "torchvision" }, @@ -2505,7 +2650,7 @@ dependencies = [ { name = "xgrammar", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 's390x' or platform_machine == 'x86_64'" }, ] wheels = [ - { url = "https://wheels.vllm.ai/ecd0b60aad2f4e28dd00ababfc1402690d88cbed/vllm-0.20.2rc1.dev168%2Bgecd0b60aa.cu129-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:ffc821955e01472615540047d585a5264b6cdc64b21b9273bbb9db18ee0c539d" }, + { url = "https://github.com/vllm-project/vllm/releases/download/v0.23.0/vllm-0.23.0%2Bcu129-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:8bc2203995d061e6b988916b71b9dee8a5970f5fdc5f37d4445a877a2fab2cc1" }, ] [package.metadata] @@ -2518,35 +2663,36 @@ requires-dist = [ { name = "cachetools" }, { name = "cbor2" }, { name = "cloudpickle" }, - { name = "compressed-tensors", specifier = "==0.15.0.1" }, + { name = "compressed-tensors", specifier = "==0.17.0" }, { name = "datasets", marker = "extra == 'bench'" }, { name = "depyf", specifier = "==0.20.0" }, { name = "diskcache", specifier = "==5.6.3" }, { name = "einops" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.0" }, - { name = "fastsafetensors", specifier = ">=0.2.2" }, - { name = "fastsafetensors", marker = "extra == 'fastsafetensors'", specifier = ">=0.2.2" }, + { name = "fastsafetensors", specifier = ">=0.3.2" }, + { name = "fastsafetensors", marker = "extra == 'fastsafetensors'", specifier = ">=0.3.2" }, { name = "filelock", specifier = ">=3.16.1" }, - { name = "flashinfer-cubin", specifier = "==0.6.8.post1" }, - { name = "flashinfer-python", specifier = "==0.6.8.post1" }, + { name = "flashinfer-cubin", specifier = "==0.6.12" }, + { name = 
"flashinfer-python", specifier = "==0.6.12" }, { name = "gguf", specifier = ">=0.17.0" }, { name = "helion", marker = "extra == 'helion'", specifier = "==1.0.0" }, + { name = "humming-kernels", extras = ["cu12"], specifier = "==0.1.4" }, { name = "ijson" }, { name = "instanttensor", marker = "extra == 'instanttensor'", specifier = ">=0.1.5" }, { name = "lark", specifier = "==1.2.2" }, - { name = "llguidance", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 'x86_64'", specifier = ">=1.3.0,<1.4.0" }, + { name = "llguidance", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 'x86_64'", specifier = ">=1.7.0,<1.8.0" }, { name = "lm-format-enforcer", specifier = "==0.11.3" }, { name = "matplotlib", marker = "extra == 'bench'" }, { name = "mcp" }, { name = "mistral-common", extras = ["audio"], marker = "extra == 'audio'" }, - { name = "mistral-common", extras = ["image"], specifier = ">=1.11.2" }, + { name = "mistral-common", extras = ["image"], specifier = ">=1.11.3" }, { name = "model-hosting-container-standards", specifier = ">=0.1.14,<1.0.0" }, { name = "msgspec" }, { name = "ninja" }, { name = "numba", specifier = "==0.65.0" }, { name = "numpy" }, - { name = "nvidia-cudnn-frontend", specifier = ">=1.13.0,<1.19.0" }, - { name = "nvidia-cutlass-dsl", specifier = ">=4.4.2" }, + { name = "nvidia-cudnn-frontend", specifier = ">=1.19.1" }, + { name = "nvidia-cutlass-dsl", specifier = "==4.5.2" }, { name = "openai", specifier = ">=2.0.0" }, { name = "openai-harmony", specifier = ">=0.0.3" }, { name = "opencv-python-headless", specifier = ">=4.13.0" }, @@ -2577,6 +2723,7 @@ requires-dist = [ { name = "regex" }, { name = "requests", specifier = ">=2.26.0" }, { name = "runai-model-streamer", extras = ["azure", "gcs", "s3"], marker = "extra == 'runai'", specifier = ">=0.15.7" }, + { name = "safetensors", specifier = 
">=0.6.2" }, { name = "scipy", marker = "extra == 'audio'" }, { name = "scipy", marker = "extra == 'bench'" }, { name = "seaborn", marker = "extra == 'bench'" }, @@ -2590,6 +2737,7 @@ requires-dist = [ { name = "tiktoken", specifier = ">=0.6.0" }, { name = "tilelang", specifier = "==0.1.9" }, { name = "tokenizers", specifier = ">=0.21.1" }, + { name = "tokenspeed-mla", specifier = "==0.1.2" }, { name = "torch", specifier = "==2.11.0" }, { name = "torchaudio", specifier = "==2.11.0" }, { name = "torchvision", specifier = "==0.26.0" }, @@ -2598,7 +2746,7 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.10" }, { name = "watchfiles" }, { name = "xgrammar", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 's390x' or platform_machine == 'x86_64'", specifier = ">=0.2.0,<1.0.0" }, - { name = "zentorch-weekly", marker = "extra == 'zen'", specifier = "==5.2.1.dev20260408" }, + { name = "zentorch", marker = "extra == 'zen'", specifier = "==2.11.0.0" }, ] provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttensor", "runai", "audio", "video", "flashinfer", "helion", "grpc", "otel"]