def main(
    model: str = "Qwen/Qwen3-0.6B",
    dataset: str = "roneneldan/TinyStories",
    split: str = "train",
    text_column: str = "text",
    samples: int = 16,
    steps: int = 1,
    micro_batch_size: int = 1,
    lr: float = 5e-5,
    layers: int = 2,
    max_seq_length: int = 256,
) -> None:
    """Run a small next-token training loop through ``TrainerRank``.

    Streams text rows from ``dataset``, tokenizes them into shifted
    input/target pairs, builds a Megatron training runtime and performs
    ``steps`` optimizer steps over the collected examples, printing metrics
    on distributed rank 0.

    Args:
        model: Identifier passed to both the tokenizer and the runtime builder.
        dataset: Dataset name handed to ``datasets.load_dataset``.
        split: Dataset split to stream.
        text_column: Row key holding the raw text.
        samples: Stop collecting once this many examples were tokenized.
        steps: Number of optimizer steps; every step walks all examples.
        micro_batch_size: Forwarded to ``TrainerRank``.
        lr: Learning rate placed in ``AdamParams``.
        layers: Value assigned to the provider's ``num_layers``.
        max_seq_length: Rows are truncated to ``max_seq_length + 1`` tokens so
            the shifted input and target each hold at most this many tokens.

    Raises:
        RuntimeError: If CUDA is unavailable or no usable example was found.
        KeyError: If ``LOCAL_RANK`` is not set in the environment.
    """
    # Parallel sizes default to 1 but can be overridden from the environment.
    os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1")
    os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1")
    os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1")

    if not torch.cuda.is_available():
        raise RuntimeError("dev/trainer_rank.py requires CUDA")
    # LOCAL_RANK is read directly; it is presumably provided by the launcher
    # (e.g. torchrun) -- confirm for other launch methods.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group(backend="nccl")

    try:
        # Imported lazily, after the process group exists.
        from datasets import load_dataset

        from art.megatron import train as megatron_train

        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository; only use with trusted model identifiers.
        tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
        inputs: list[ForwardInput[torch.Tensor, None, None, None]] = []
        for row in load_dataset(dataset, split=split, streaming=True):
            text = str(row.get(text_column, "")).strip()  # type: ignore[union-attr]
            if not text:
                continue
            # One extra token is kept so input/target can be shifted by one.
            token_ids = tokenizer(
                text,
                add_special_tokens=True,
                truncation=True,
                max_length=max_seq_length + 1,
                return_tensors="pt",
            )["input_ids"].reshape(-1)
            # A single token cannot form an (input, target) pair.
            if int(token_ids.numel()) <= 1:
                continue
            inputs.append(
                ForwardInput(
                    input_tokens=token_ids[:-1],
                    target_tokens=token_ids[1:],
                )
            )
            if len(inputs) >= samples:
                break
        if not inputs:
            raise RuntimeError("dataset produced no tokenized training examples")

        runtime = megatron_train.build_training_runtime(
            model_identifier=model,
            provider_configure=lambda provider: setattr(
                provider,
                "num_layers",
                layers,
            ),
            print_env=dist.get_rank() == 0,
        )
        rank = TrainerRank(runtime, micro_batch_size=micro_batch_size)
        if dist.get_rank() == 0:
            print(
                "TrainerRank ready: "
                f"dp={megatron_train.ps.get_data_parallel_world_size()} "
                f"device={rank.device}",
                flush=True,
            )

        for step in range(steps):
            # Per-step accumulators; both live on the rank's device.
            loss_sum = torch.tensor(0.0, device=rank.device)
            token_count = torch.tensor(0.0, device=rank.device)
            for micro in rank.micro_batches(inputs):
                outputs = rank.forward(micro.inputs)
                loss = torch.tensor(0.0, device=rank.device)
                for output in outputs:
                    assert output.target_logprobs is not None
                    # Negative log-likelihood summed over target tokens.
                    loss = loss - output.target_logprobs.sum()
                    token_count += output.target_logprobs.numel()
                # The zero-initialised loss carries no graph when a micro
                # batch yields no outputs; skip backward in that case.
                if loss.requires_grad:
                    loss.backward()
                loss_sum += loss.detach()

            # Return values are ignored, so dp_reduce presumably reduces in
            # place across data-parallel ranks -- confirm in TrainerRank.
            rank.dp_reduce(loss_sum)
            rank.dp_reduce(token_count)
            # Normalise by token count; the max() guards a zero-token step.
            scale = 1.0 / max(float(token_count.item()), 1.0)
            metrics = rank.optim_step(
                params=AdamParams(learning_rate=lr),
                scale_grads=scale,
            )
            metrics["loss"] = float(loss_sum.item() * scale)
            metrics["tokens"] = float(token_count.item())
            if dist.get_rank() == 0:
                print(f"step={step} {metrics}", flush=True)

        dist.barrier()
    finally:
        # Always tear the process group down, even after a failure.
        if dist.is_initialized():
            dist.destroy_process_group()
def main(
    model: str = "Qwen/Qwen3-0.6B",
    layers: int = 1,
    sequences: int = 6,
    sequence_length: int = 7,
    compare_requests: int = 6,
    request_shape: str = "varied",
    oracle: str = "independent",
    max_depth: int = 1,
) -> None:
    """Compare packed ``TrainerRank`` activations against an oracle run.

    All synthetic requests are run once as a single packed batch, then each
    compared request is run again through an oracle. Rank 0 prints a JSON
    report of per-site differences between the two.

    Args:
        model: Identifier passed to the runtime builder.
        layers: Value assigned to the provider's ``num_layers``.
        sequences: Number of synthetic requests to build.
        sequence_length: Length of each request when ``request_shape`` is
            ``"equal"``.
        compare_requests: Only the first ``compare_requests`` requests are
            compared and reported.
        request_shape: Shape preset forwarded to ``_unique_requests``.
        oracle: ``"independent"`` runs each request on its own;
            ``"same-layout"`` re-runs the full batch with every other
            request's tokens replaced.
        max_depth: Forwarded to ``TrainerRank`` as ``shared_prefix_max_depth``.

    Raises:
        ValueError: If ``oracle`` is not a supported mode.
        RuntimeError: If CUDA is unavailable or tensor parallelism is not 1.
        KeyError: If ``LOCAL_RANK`` is not set in the environment.
    """
    # Fail fast: rejecting a bad mode here avoids building the runtime and
    # running the packed forward before the error surfaces.
    if oracle not in ("independent", "same-layout"):
        raise ValueError("oracle must be 'independent' or 'same-layout'")

    os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1")
    os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1")
    os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1")

    # Same guard as dev/trainer_rank.py, so a CPU-only host gets a clear error.
    if not torch.cuda.is_available():
        raise RuntimeError("dev/trainer_rank_parity_probe.py requires CUDA")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group(backend="nccl")
    try:
        from megatron.core import parallel_state as ps

        from art.megatron import train as megatron_train

        torch.manual_seed(1234)
        runtime = megatron_train.build_training_runtime(
            model_identifier=model,
            provider_configure=lambda provider: setattr(
                provider,
                "num_layers",
                layers,
            ),
            print_env=dist.get_rank() == 0,
        )
        if int(ps.get_tensor_model_parallel_world_size()) != 1:
            raise RuntimeError("trainer_rank_parity_probe currently expects TP=1")
        for chunk in runtime.model:
            chunk.eval()

        rank = TrainerRank(runtime, shared_prefix_max_depth=max_depth)
        requests = _unique_requests(
            sequences=sequences,
            sequence_length=sequence_length,
            request_shape=request_shape,
        )
        request_count = min(compare_requests, len(requests))
        # Only these requests reach the report, so only they need an oracle
        # forward pass. The slice is the one the report uses.
        compared = requests[:request_count]
        cp_rank = int(ps.get_context_parallel_rank())
        dp_rank = int(ps.get_data_parallel_rank())

        with torch.no_grad():
            packed = _run_capture(rank, requests)
            records = _records_from_capture(
                kind="packed",
                capture=packed,
                request_indices=range(len(requests)),
                cp_rank=cp_rank,
                dp_rank=dp_rank,
            )
            for request_index, request in enumerate(compared):
                if oracle == "independent":
                    oracle_capture = _run_capture(rank, [request])
                    oracle_request_indices = (request_index,)
                    oracle_local_indices = None
                else:  # "same-layout" (validated on entry)
                    oracle_capture = _run_capture(
                        rank,
                        requests,
                        mutate_except=request_index,
                    )
                    oracle_request_indices = range(len(requests))
                    oracle_local_indices = (request_index,)
                # Oracle records are always tagged "independent"; that is the
                # kind the report builder looks up.
                records.extend(
                    _records_from_capture(
                        kind="independent",
                        capture=oracle_capture,
                        request_indices=oracle_request_indices,
                        cp_rank=cp_rank,
                        dp_rank=dp_rank,
                        local_indices=oracle_local_indices,
                    )
                )

        gathered: list[list[dict[str, object]] | None] = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, records)
        if dist.get_rank() == 0:
            flat_records = [
                record for rank_records in gathered for record in rank_records or []
            ]
            report = _build_report(
                records=flat_records,
                requests=compared,
                topology={
                    "world": dist.get_world_size(),
                    "dp": int(ps.get_data_parallel_world_size()),
                    "tp": int(ps.get_tensor_model_parallel_world_size()),
                    "cp": int(ps.get_context_parallel_world_size()),
                },
                oracle=oracle,
            )
            print(json.dumps(report, sort_keys=True), flush=True)
        dist.barrier()
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()
def _run_capture(
    rank: TrainerRank,
    requests: Sequence[AnyForwardInput],
    *,
    mutate_except: int | None = None,
) -> _Capture:
    """Run one packed forward pass and capture intermediate activations.

    The requests are packed into a single batch, forward hooks are installed
    on the language model, and the decoder is driven manually so that the
    embedding output, decoder output and logits can be recorded alongside the
    hooked module outputs.

    Args:
        rank: Trainer rank providing the runtime and packing helpers.
        requests: Requests to pack into one batch.
        mutate_except: When given, every packed token outside the positions of
            this item index is replaced before the forward pass.

    Returns:
        A ``_Capture`` with the recorded tensors and the per-item position
        tensors taken from the prepared batch.
    """
    from art.megatron.train import _placeholder_attention_mask

    model = _language_model(rank.runtime.model[0])
    items = [rank._forward_item(request) for request in requests]
    batch = _pack_forward_items(items, max_depth=rank.shared_prefix_max_depth)
    if mutate_except is not None:
        batch = _mutated_batch(
            batch, keep_positions=batch.positions_by_item[mutate_except]
        )
    prepared = rank._prepare_packed_forward(batch)
    # Dimension 1 of the prepared tokens is used as the sequence length that
    # the hooks match against -- assumes (batch, seq) layout, TODO confirm.
    local_seq_len = int(prepared.tokens.shape[1])
    values: dict[str, torch.Tensor] = {}
    handles = _register_hooks(model, values, seq_len=local_seq_len)
    try:
        handler = rank._handler()
        forward_kwargs = handler.get_forward_kwargs(
            rank.runtime.model[0],
            attention_bias=prepared.attention_state,
        )
        # Only "extra_block_kwargs" is consumed; the rest of forward_kwargs is
        # not used below.
        extra_block_kwargs = cast(
            dict[str, object] | None,
            forward_kwargs.pop("extra_block_kwargs", None),
        )
        preprocessed = model._preprocess(
            input_ids=prepared.tokens,
            position_ids=prepared.position_ids,
            packed_seq_params=prepared.packed_seq_params,
        )
        # Numeric prefixes on the keys keep the capture sites sortable.
        values["00.preprocess.decoder_input"] = _rows(
            cast(torch.Tensor, preprocessed[0]).detach(),
            seq_len=local_seq_len,
        )
        # The _preprocess result is consumed positionally; a seventh element
        # is passed as rotary_pos_cos_sin only when it is present.
        hidden = cast(
            torch.Tensor,
            model.decoder(
                hidden_states=preprocessed[0],
                attention_mask=_placeholder_attention_mask(rank.device),
                rotary_pos_emb=preprocessed[1],
                rotary_pos_cos=preprocessed[2],
                rotary_pos_sin=preprocessed[3],
                rotary_pos_cos_sin=preprocessed[6] if len(preprocessed) == 7 else None,
                packed_seq_params=prepared.packed_seq_params,
                sequence_len_offset=preprocessed[4],
                padding_mask=preprocessed[5],
                **(extra_block_kwargs or {}),
            ),
        )
        gathered_hidden = rank._gather_sequence_parallel_hidden(hidden)
        values["90.decoder.output"] = gathered_hidden.detach()
        values["99.lm_head.logits"] = _logits(rank, gathered_hidden).detach()
        return _Capture(
            values=values,
            positions_by_item=prepared.positions_by_item,
            source_positions_by_item=prepared.source_positions_by_item,
        )
    finally:
        # Hooks are removed even when the forward pass fails.
        for handle in handles:
            handle.remove()


def _mutated_batch(
    batch: _PackedForwardBatch,
    *,
    keep_positions: torch.Tensor,
) -> _PackedForwardBatch:
    """Return a copy of ``batch`` with tokens outside ``keep_positions`` replaced.

    Group ids, parent ids, position ids and per-item positions are reused
    unchanged, so only the token values differ from the input batch.

    Args:
        batch: Packed batch to copy.
        keep_positions: Indices along dimension 1 of ``batch.tokens`` whose
            tokens are left as they are.

    Returns:
        A new ``_PackedForwardBatch`` holding the modified token tensor.
    """
    tokens = batch.tokens.clone()
    # True marks positions whose token is overwritten.
    mask = torch.ones(int(tokens.shape[1]), dtype=torch.bool, device=tokens.device)
    mask[keep_positions.to(device=tokens.device)] = False
    # Position-dependent replacement ids: (position + 50_000) % 100_000.
    replacement = (
        torch.arange(int(tokens.shape[1]), dtype=tokens.dtype, device=tokens.device)
        + 50_000
    )
    # Only row 0 is rewritten -- assumes a single packed row, TODO confirm.
    tokens[0, mask] = replacement[mask] % 100_000
    return _PackedForwardBatch(
        tokens=tokens,
        group_ids=batch.group_ids,
        parent_ids=batch.parent_ids,
        position_ids=batch.position_ids,
        positions_by_item=batch.positions_by_item,
    )
_rows(tensor.detach(), seq_len=seq_len) + except RuntimeError: + pass + + handles.append(module.register_forward_hook(hook)) + return handles + + +def _capture_label(module_name: str) -> str | None: + layer_prefix = r"decoder\.layers\.(\d+)(?:\._orig_mod)?" + if re.fullmatch(r"decoder\.layers\.(\d+)\._orig_mod", module_name): + return None + layer_match = re.fullmatch(r"decoder\.layers\.(\d+)", module_name) + if layer_match: + return f"30.layer.{int(layer_match.group(1)):03d}.output" + input_norm_match = re.fullmatch(rf"{layer_prefix}\.input_layernorm", module_name) + if input_norm_match: + return f"05.layer.{int(input_norm_match.group(1)):03d}.input_layernorm" + qkv_match = re.fullmatch( + rf"{layer_prefix}\.self_attention\.linear_qkv", module_name + ) + if qkv_match: + return f"08.layer.{int(qkv_match.group(1)):03d}.self_attention.linear_qkv" + core_attention_match = re.fullmatch( + rf"{layer_prefix}\.self_attention\.core_attention", + module_name, + ) + if core_attention_match: + return f"10.layer.{int(core_attention_match.group(1)):03d}.self_attention.core_attention" + attention_proj_match = re.fullmatch( + rf"{layer_prefix}\.self_attention\.linear_proj", + module_name, + ) + if attention_proj_match: + return f"12.layer.{int(attention_proj_match.group(1)):03d}.self_attention.linear_proj" + attention_match = re.fullmatch( + rf"{layer_prefix}\.self_attention", + module_name, + ) + if attention_match: + return f"15.layer.{int(attention_match.group(1)):03d}.self_attention" + pre_mlp_norm_match = re.fullmatch( + rf"{layer_prefix}\.pre_mlp_layernorm", + module_name, + ) + if pre_mlp_norm_match: + return f"18.layer.{int(pre_mlp_norm_match.group(1)):03d}.pre_mlp_layernorm" + fc1_match = re.fullmatch(rf"{layer_prefix}\.mlp\.linear_fc1", module_name) + if fc1_match: + return f"20.layer.{int(fc1_match.group(1)):03d}.mlp.linear_fc1" + fc2_match = re.fullmatch(rf"{layer_prefix}\.mlp\.linear_fc2", module_name) + if fc2_match: + return 
f"22.layer.{int(fc2_match.group(1)):03d}.mlp.linear_fc2" + mlp_match = re.fullmatch(rf"{layer_prefix}\.mlp", module_name) + if mlp_match: + return f"25.layer.{int(mlp_match.group(1)):03d}.mlp" + if module_name == "decoder.final_layernorm": + return "80.decoder.final_layernorm" + return None + + +def _first_tensor(value: object) -> torch.Tensor | None: + if isinstance(value, torch.Tensor): + return value + if isinstance(value, (tuple, list)): + for item in value: + tensor = _first_tensor(item) + if tensor is not None: + return tensor + return None + + +def _rows(tensor: torch.Tensor, *, seq_len: int) -> torch.Tensor: + if tensor.ndim >= 2 and int(tensor.shape[0]) == seq_len: + rows = tensor + if rows.ndim >= 3 and int(rows.shape[1]) == 1: + return rows[:, 0].contiguous() + return rows.contiguous() + if tensor.ndim >= 2 and int(tensor.shape[1]) == seq_len: + rows = ( + tensor[:, :, 0] + if tensor.ndim == 4 and int(tensor.shape[2]) == 1 + else tensor + ) + if int(rows.shape[0]) == 1: + return rows[0].contiguous() + raise RuntimeError( + f"Cannot identify sequence axis for tensor shape={tuple(tensor.shape)} " + f"seq_len={seq_len}" + ) + + +def _logits(rank: TrainerRank, hidden_rows: torch.Tensor) -> torch.Tensor: + model = _language_model(rank.runtime.model[0]) + output_weight = ( + model.shared_embedding_or_output_weight() + if bool(model.share_embeddings_and_output_weights) + else None + ) + if int(hidden_rows.shape[0]) == 0: + return hidden_rows.new_empty((0, int(model.vocab_size))) + return rank._logits_from_hidden_rows( + model, + hidden_rows, + output_weight=output_weight, + ) + + +def _records_from_capture( + *, + kind: str, + capture: _Capture, + request_indices: Sequence[int], + cp_rank: int, + dp_rank: int, + local_indices: Sequence[int] | None = None, +) -> list[dict[str, object]]: + records: list[dict[str, object]] = [] + local_index_set = None if local_indices is None else frozenset(local_indices) + for local_index, request_index in 
enumerate(request_indices): + if local_index_set is not None and local_index not in local_index_set: + continue + positions = capture.positions_by_item[local_index] + source_positions = capture.source_positions_by_item[local_index] + if int(positions.numel()) == 0: + continue + for name, rows in capture.values.items(): + records.append( + { + "kind": kind, + "name": name, + "request_index": int(request_index), + "source_positions": source_positions.cpu(), + "value": rows.index_select(0, positions.to(rows.device)).cpu(), + "cp": int(cp_rank), + "dp": int(dp_rank), + } + ) + return records + + +def _build_report( + *, + records: list[dict[str, object]], + requests: Sequence[AnyForwardInput], + topology: dict[str, int], + oracle: str, +) -> dict[str, object]: + results = [] + names = sorted( + { + cast(str, record["name"]) + for record in records + if record.get("kind") == "packed" + } + ) + for request_index, request in enumerate(requests): + length = int(request.input_tokens.numel()) + for name in names: + packed = _assemble(records, "packed", name, request_index, length) + independent = _assemble(records, "independent", name, request_index, length) + if packed is None or independent is None: + continue + diff = (packed.float() - independent.float()).abs() + denom = independent.float().abs().max().clamp_min(1e-12) + results.append( + { + "request": request_index, + "site": name, + "shape": list(packed.shape), + "max_abs": float(diff.max().item()) if int(diff.numel()) else 0.0, + "mean_abs": float(diff.mean().item()) if int(diff.numel()) else 0.0, + "rel_max": float((diff.max() / denom).item()) + if int(diff.numel()) + else 0.0, + } + ) + return { + "topology": topology, + "oracle": oracle, + "requests": len(requests), + "results": results, + } + + +def _assemble( + records: list[dict[str, object]], + kind: str, + name: str, + request_index: int, + length: int, +) -> torch.Tensor | None: + matching = [ + record + for record in records + if record["kind"] == kind + 
and record["name"] == name + and record["request_index"] == request_index + ] + if not matching: + return None + first = cast(torch.Tensor, matching[0]["value"]) + output = torch.empty((length, *first.shape[1:]), dtype=first.dtype) + filled = torch.zeros(length, dtype=torch.bool) + for record in matching: + positions = cast(torch.Tensor, record["source_positions"]) + value = cast(torch.Tensor, record["value"]) + output[positions] = value + filled[positions] = True + if not bool(filled.all().item()): + raise RuntimeError( + f"Missing positions for {kind} {name} request={request_index}" + ) + return output + + +if __name__ == "__main__": + typer.run(main) diff --git a/dev/trainer_rank_perf.py b/dev/trainer_rank_perf.py new file mode 100644 index 000000000..cabf3c93e --- /dev/null +++ b/dev/trainer_rank_perf.py @@ -0,0 +1,2011 @@ +from __future__ import annotations + +from collections.abc import Callable, Sequence +from contextlib import suppress +import json +import os +from pathlib import Path +import threading +import time +from typing import Any + +import torch +import torch.distributed as dist +import typer + +from art.megatron.trainer_rank import ( + AdamParams, + ForwardInput, + TopK, + TrainerRank, + _batch_seq_logits, + _language_model, + _pack_forward_items, +) + + +def main( + model: str = "Qwen/Qwen3-0.6B", + layers: int = 1, + seq_len: int = 2048, + prefix_families: int = 0, + prefix_len: int = 5000, + mid_prefixes_per_family: int = 1, + mid_prefix_len: int = 0, + branches_per_prefix: int = 16, + completion_len: int = 100, + warmup: int = 2, + repeat: int = 5, + head_chunk_tokens: int = 512, + shared_prefix_max_depth: int = 1, + benchmark: str = "target_builtin_fwd", + target_count: int = 4, + top_k: int = 5, + top_k_values: str = "1,2,5,10,20,50", + max_unpacked_output_gb: float = 0.5, + mask_prefix_targets: bool = True, + workload: str = "regular", + tree_depth: int = 3, + tree_seed: int = 1, + tree_duplicate_factor: int = 1, + adapter_slots: int = 0, + 
adapter_slot_mode: str = "family", + adapter_slot_rank: int = 1, + learning_rate: float = 1e-5, + full_step_offload_reload: bool = False, + memory_sample_interval_s: float = 0.05, + compare_target_correctness: bool = False, + run_adapter_sanity: bool = False, + output_jsonl: str = "", +) -> None: + os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1") + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + try: + from megatron.core import parallel_state as ps + + from art.megatron import train as megatron_train + + provider_configure = ( + (lambda provider: setattr(provider, "num_layers", layers)) + if layers > 0 + else None + ) + runtime = megatron_train.build_training_runtime( + model_identifier=model, + provider_configure=provider_configure, + print_env=dist.get_rank() == 0, + ) + for chunk in runtime.model: + chunk.eval() + rank = TrainerRank( + runtime, + head_chunk_tokens=head_chunk_tokens, + shared_prefix_max_depth=shared_prefix_max_depth, + ) + if adapter_slots < 0: + raise ValueError("adapter_slots must be >= 0") + if adapter_slot_rank < 1: + raise ValueError("adapter_slot_rank must be >= 1") + if adapter_slots: + loaded_sites = _load_adapter_slots( + rank, + count=adapter_slots, + slot_rank=adapter_slot_rank, + ) + else: + loaded_sites = 0 + hidden_size, vocab_size, dtype_size = _runtime_output_shape(runtime) + model_config = getattr(_language_model(runtime.model[0]), "config", None) + + benchmarks = { + name.strip().replace("-", "_") + for name in benchmark.split(",") + if name.strip() + } + if "all" in benchmarks: + benchmarks = { + "target_builtin_fwd", + "target_trainer_fwd", + "target_hidden_fwd", + "logits_builtin_fwd", + "logits_hidden_fwd", + "target_builtin_fwd_bwd", + "target_builtin_masked_fwd_bwd", + "target_trainer_fwd_bwd", + 
"target_hidden_fwd_bwd", + "target_builtin_train_step", + "target_trainer_train_step", + "target_hidden_train_step", + "trainer_multi_target_fwd_bwd", + "trainer_multi_target_train_step", + "trainer_target", + "trainer_multi_target", + "trainer_topk", + "trainer_topk_head", + "trainer_topk_fwd_bwd", + "trainer_topk_train_step", + "trainer_topk_sweep", + "trainer_target_topk", + "trainer_hidden", + "trainer_all_no_logits", + "trainer_logits", + } + if "trainer_all" in benchmarks: + benchmarks.update( + { + "trainer_target", + "trainer_multi_target", + "trainer_multi_target_fwd_bwd", + "trainer_multi_target_train_step", + "trainer_topk", + "trainer_topk_head", + "trainer_topk_fwd_bwd", + "trainer_topk_train_step", + "trainer_topk_sweep", + "trainer_target_topk", + "trainer_hidden", + "trainer_all_no_logits", + "trainer_logits", + } + ) + + if target_count < 1: + raise ValueError("target_count must be >= 1") + if top_k < 1: + raise ValueError("top_k must be >= 1") + if memory_sample_interval_s < 0: + raise ValueError("memory_sample_interval_s must be >= 0") + requests, multi_target_requests, request_metadata = _requests( + seq_len=seq_len, + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + target_count=target_count, + mask_prefix_targets=mask_prefix_targets, + workload=workload, + tree_depth=tree_depth, + tree_seed=tree_seed, + tree_duplicate_factor=tree_duplicate_factor, + ) + requests = _route_adapter_slots( + requests, + adapter_slots=adapter_slots, + mode=adapter_slot_mode, + ) + multi_target_requests = _route_adapter_slots( + multi_target_requests, + adapter_slots=adapter_slots, + mode=adapter_slot_mode, + ) + stats_items = [rank._forward_item(request) for request in requests] + stats_batch = _pack_forward_items( + stats_items, + max_depth=rank.shared_prefix_max_depth, + ) + stats_prepared = 
rank._prepare_packed_forward(stats_batch) + request_stats = _packed_request_stats( + requests, + stats_items, + stats_batch, + request_metadata=request_metadata, + ) + planner_metadata = _gather_planner_metadata(stats_prepared) + target_items = None + target_prepared = None + if any(name.startswith("target_") for name in benchmarks): + target_items = stats_items + target_prepared = stats_prepared + logits_items = None + logits_prepared = None + if any(name.startswith("logits_") for name in benchmarks): + logits_items = [ + rank._forward_item(_with_outputs(request, logits=True)) + for request in requests + ] + logits_prepared = rank._prepare_packed_forward( + _pack_forward_items( + logits_items, + max_depth=rank.shared_prefix_max_depth, + ) + ) + results: dict[str, float] = {} + metadata: dict[str, object] = {} + rate_units: dict[str, dict[str, int]] = {} + + def register_case( + name: str, + case_requests: Sequence[ + ForwardInput[ + torch.Tensor | None, + TopK | None, + torch.Tensor | None, + torch.Tensor | None, + ] + ], + case_stats: dict[str, int | str], + ) -> None: + units = _rate_units( + case_requests, + case_stats, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + rate_units[name] = units + for key, value in units.items(): + metadata[f"{name}_{key}"] = value + + for name in ( + "target_builtin_fwd", + "target_hidden_fwd", + "target_trainer_fwd", + "target_builtin_fwd_bwd", + "target_builtin_masked_fwd_bwd", + "target_trainer_fwd_bwd", + "target_hidden_fwd_bwd", + "target_builtin_train_step", + "target_trainer_train_step", + "target_hidden_train_step", + ): + register_case(name, requests, request_stats) + + memory_tracker = _CudaMemoryTracker( + device_index=int(os.environ["LOCAL_RANK"]), + sample_interval_s=memory_sample_interval_s, + ) + memory_tracker.start() + torch.cuda.reset_peak_memory_stats() + with torch.no_grad(): + if "target_builtin_fwd" in benchmarks: + assert target_items is not None and target_prepared is not 
None + results["target_builtin_fwd_ms"] = _bench( + lambda: _builtin( + rank, + target_prepared, + _packed_labels(target_items, target_prepared), + ), + warmup=warmup, + repeat=repeat, + ) + if "target_hidden_fwd" in benchmarks: + assert target_items is not None and target_prepared is not None + results["target_hidden_fwd_ms"] = _bench( + lambda: rank._project_head( + target_items, + target_prepared, + rank._gather_sequence_parallel_hidden( + rank._decoder_hidden(target_prepared) + ), + ), + warmup=warmup, + repeat=repeat, + ) + if "target_trainer_fwd" in benchmarks: + assert target_items is not None and target_prepared is not None + results["target_trainer_fwd_ms"] = _bench( + lambda: rank._forward_packed(target_items, target_prepared), + warmup=warmup, + repeat=repeat, + ) + if "logits_builtin_fwd" in benchmarks: + assert logits_prepared is not None + register_case( + "logits_builtin_fwd", _logits_requests(requests), request_stats + ) + results["logits_builtin_fwd_ms"] = _bench( + lambda: _full_logits(rank, logits_prepared), + warmup=warmup, + repeat=repeat, + ) + if "logits_hidden_fwd" in benchmarks: + assert logits_items is not None and logits_prepared is not None + register_case( + "logits_hidden_fwd", _logits_requests(requests), request_stats + ) + results["logits_hidden_fwd_ms"] = _bench( + lambda: rank._project_head( + logits_items, + logits_prepared, + rank._gather_sequence_parallel_hidden( + rank._decoder_hidden(logits_prepared) + ), + ), + warmup=warmup, + repeat=repeat, + ) + trainer_cases = { + "trainer_target": requests, + "trainer_multi_target": multi_target_requests, + "trainer_topk": [ + _with_outputs(request, top_k=top_k) for request in requests + ], + "trainer_target_topk": [ + _with_outputs( + request, + target_tokens=request.target_tokens, + top_k=top_k, + ) + for request in requests + ], + "trainer_hidden": [ + _with_outputs(request, hidden_states=True) for request in requests + ], + "trainer_all_no_logits": [ + _with_outputs( + request, + 
target_tokens=multi_request.target_tokens, + top_k=top_k, + hidden_states=True, + ) + for request, multi_request in zip( + requests, multi_target_requests, strict=True + ) + ], + "trainer_logits": [ + ForwardInput(input_tokens=request.input_tokens, logits=True) + for request in requests + ], + } + if "trainer_topk_sweep" in benchmarks: + for k in _int_values(top_k_values): + trainer_cases[f"trainer_topk_{k}"] = [ + _with_outputs(request, top_k=k) for request in requests + ] + for name, case_requests in trainer_cases.items(): + if name not in benchmarks and not ( + "trainer_topk_sweep" in benchmarks + and name.startswith("trainer_topk_") + ): + continue + output_gb = _request_output_gb( + case_requests, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + metadata[f"{name}_output_gb"] = round(output_gb, 3) + if max_unpacked_output_gb > 0 and output_gb > max_unpacked_output_gb: + metadata[f"{name}_skipped"] = "unpacked_output_cap" + continue + items = [rank._forward_item(request) for request in case_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + name, + case_requests, + _packed_request_stats( + case_requests, items, batch, request_metadata={} + ), + ) + prepared = rank._prepare_packed_forward(batch) + if adapter_slots: + results[f"{name}_ms"] = _bench( + lambda case_requests=case_requests: rank.forward(case_requests), + warmup=warmup, + repeat=repeat, + ) + else: + results[f"{name}_ms"] = _bench( + lambda items=items, prepared=prepared: rank._forward_packed( + items, + prepared, + ), + warmup=warmup, + repeat=repeat, + ) + if "trainer_topk_head" in benchmarks: + case_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + output_gb = _request_output_gb( + case_requests, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + metadata["trainer_topk_head_output_gb"] = round(output_gb, 3) + items = 
[rank._forward_item(request) for request in case_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_topk_head", + case_requests, + _packed_request_stats( + case_requests, items, batch, request_metadata={} + ), + ) + prepared = rank._prepare_packed_forward(batch) + hidden = rank._gather_sequence_parallel_hidden( + rank._decoder_hidden(prepared) + ) + results["trainer_topk_head_ms"] = _bench( + lambda: rank._project_head(items, prepared, hidden), + warmup=warmup, + repeat=repeat, + ) + + if "target_builtin_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_builtin_fwd_bwd_ms"] = _bench( + lambda: _target_builtin_loss( + rank, + target_items, + target_prepared, + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "target_builtin_masked_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_builtin_masked_fwd_bwd_ms"] = _bench( + lambda: _target_builtin_masked_loss( + rank, + target_items, + target_prepared, + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "target_trainer_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_trainer_fwd_bwd_ms"] = _bench( + lambda: ( + _target_requests_loss(rank, requests) + if adapter_slots + else _target_trainer_loss( + rank, + target_items, + target_prepared, + ) + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "target_hidden_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_hidden_fwd_bwd_ms"] = _bench( + lambda: _target_hidden_loss( + rank, + target_items, + 
target_prepared, + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + train_step_params = AdamParams(learning_rate=learning_rate) + offload_manager = ( + _make_offload_manager(runtime) if full_step_offload_reload else None + ) + if "target_builtin_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_builtin_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: _target_builtin_loss(rank, target_items, target_prepared), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if "target_trainer_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_trainer_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: ( + _target_requests_loss(rank, requests) + if adapter_slots + else _target_trainer_loss(rank, target_items, target_prepared) + ), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if "target_hidden_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_hidden_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: _target_hidden_loss(rank, target_items, target_prepared), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if "trainer_multi_target_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + items = [rank._forward_item(request) for request in multi_target_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_multi_target_fwd_bwd", + multi_target_requests, + _packed_request_stats( + multi_target_requests, + items, + batch, + 
request_metadata={}, + ), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_multi_target_fwd_bwd_ms"] = _bench( + lambda: ( + _target_requests_loss(rank, multi_target_requests) + if adapter_slots + else _target_trainer_loss(rank, items, prepared) + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "trainer_multi_target_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + items = [rank._forward_item(request) for request in multi_target_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_multi_target_train_step", + multi_target_requests, + _packed_request_stats( + multi_target_requests, + items, + batch, + request_metadata={}, + ), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_multi_target_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: ( + _target_requests_loss(rank, multi_target_requests) + if adapter_slots + else _target_trainer_loss(rank, items, prepared) + ), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if "trainer_topk_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + topk_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + items = [rank._forward_item(request) for request in topk_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_topk_fwd_bwd", + topk_requests, + _packed_request_stats(topk_requests, items, batch, request_metadata={}), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_topk_fwd_bwd_ms"] = _bench( + lambda: ( + _topk_requests_loss(rank, topk_requests) + if adapter_slots + else _trainer_topk_loss(rank, items, prepared) + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "trainer_topk_train_step" in benchmarks: + for chunk 
in runtime.model: + chunk.train() + topk_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + items = [rank._forward_item(request) for request in topk_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_topk_train_step", + topk_requests, + _packed_request_stats(topk_requests, items, batch, request_metadata={}), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_topk_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: ( + _topk_requests_loss(rank, topk_requests) + if adapter_slots + else _trainer_topk_loss(rank, items, prepared) + ), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + + if compare_target_correctness and adapter_slots: + metadata["target_correctness_skipped"] = "adapter_slots" + elif compare_target_correctness: + assert target_items is not None and target_prepared is not None + metadata.update( + _target_correctness_metrics(rank, target_items, target_prepared) + ) + if run_adapter_sanity and adapter_slots > 0: + metadata.update( + _adapter_sanity_metrics( + rank, + requests, + params=train_step_params, + adapter_slots=adapter_slots, + ) + ) + + memory_tracker.stop() + memory_metadata = _distributed_memory_metadata(memory_tracker) + + if dist.get_rank() == 0: + token_rates = _rate_metrics(results, rate_units) + payload = { + "world": dist.get_world_size(), + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + "seq_len": seq_len, + "prefix_families": prefix_families, + "prefix_len": prefix_len, + "mid_prefixes_per_family": mid_prefixes_per_family, + "mid_prefix_len": mid_prefix_len, + "branches_per_prefix": branches_per_prefix, + "completion_len": completion_len, + "head_chunk_tokens": head_chunk_tokens, + "shared_prefix_max_depth": shared_prefix_max_depth, + "warmup": warmup, + "repeat": repeat, + 
"target_count": target_count, + "top_k": top_k, + "top_k_values": top_k_values, + "max_unpacked_output_gb": max_unpacked_output_gb, + "mask_prefix_targets": mask_prefix_targets, + "workload": workload, + "tree_depth": tree_depth, + "tree_seed": tree_seed, + "tree_duplicate_factor": tree_duplicate_factor, + "adapter_slots": adapter_slots, + "adapter_slot_mode": adapter_slot_mode, + "adapter_slot_rank": adapter_slot_rank, + "adapter_loaded_sites": loaded_sites, + "learning_rate": learning_rate, + "full_step_offload_reload": full_step_offload_reload, + "mtp_num_layers": getattr(model_config, "mtp_num_layers", None), + "cross_entropy_loss_fusion": getattr( + model_config, "cross_entropy_loss_fusion", None + ), + "cross_entropy_fusion_impl": getattr( + model_config, "cross_entropy_fusion_impl", None + ), + **_model_metadata(runtime, model, layers=layers), + **request_stats, + **memory_metadata, + **results, + **token_rates, + **metadata, + **planner_metadata, + } + line = json.dumps(payload, sort_keys=True) + print(line, flush=True) + if output_jsonl: + output_path = Path(output_jsonl) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("a", encoding="utf-8") as output_file: + output_file.write(line + "\n") + dist.barrier() + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _requests( + *, + seq_len: int, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, + target_count: int, + mask_prefix_targets: bool, + workload: str, + tree_depth: int, + tree_seed: int, + tree_duplicate_factor: int, +) -> tuple[ + list[ForwardInput[torch.Tensor, None, None, None]], + list[ForwardInput[torch.Tensor, None, None, None]], + dict[str, int | str], +]: + if workload == "regular" and prefix_families <= 0: + tokens = torch.arange(seq_len, dtype=torch.long) % 32_000 + 100 + labels = _labels(tokens, target_count=1) + return ( + 
[ForwardInput(input_tokens=tokens, target_tokens=labels)], + [ + ForwardInput( + input_tokens=tokens, + target_tokens=_labels(tokens, target_count=target_count), + ) + ], + { + "request_count": 1, + "workload_shape": "single", + }, + ) + + if prefix_len < 1 or branches_per_prefix < 1 or completion_len < 1: + raise ValueError( + "prefix_len, branches_per_prefix, and completion_len must be >= 1" + ) + if mid_prefixes_per_family < 1 or mid_prefix_len < 0: + raise ValueError("mid_prefixes_per_family must be >= 1 and mid_prefix_len >= 0") + + sequences, prefix_lengths, workload_shape = _workload_sequences( + workload=workload, + seq_len=seq_len, + prefix_families=max(prefix_families, 1), + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + tree_depth=tree_depth, + tree_seed=tree_seed, + tree_duplicate_factor=tree_duplicate_factor, + ) + requests = [] + multi_requests = [] + for tokens, shared_length in zip(sequences, prefix_lengths, strict=True): + labels = _labels(tokens, target_count=1) + multi_labels = _labels(tokens, target_count=target_count) + if mask_prefix_targets and shared_length: + labels[:shared_length] = -100 + multi_labels[:shared_length] = -100 + requests.append(ForwardInput(input_tokens=tokens, target_tokens=labels)) + multi_requests.append( + ForwardInput(input_tokens=tokens, target_tokens=multi_labels) + ) + + return ( + requests, + multi_requests, + { + "request_count": len(requests), + "workload_shape": workload_shape, + }, + ) + + +def _load_adapter_slots( + rank: TrainerRank, + *, + count: int, + slot_rank: int, +) -> int: + loaded_sites = 0 + for slot_index in range(count): + loaded_sites += rank.load_checkpoint_slot( + f"S{slot_index}", + _synthetic_adapter( + rank.runtime.model, slot_rank=slot_rank, seed=slot_index + ), + ) + return loaded_sites + + +def _synthetic_adapter( + model: Sequence[torch.nn.Module], + *, 
+ slot_rank: int, + seed: int, +) -> dict[str, torch.Tensor]: + from art.megatron.lora import LoRA + + adapter: dict[str, torch.Tensor] = {} + generator = torch.Generator(device="cuda").manual_seed(10_000 + seed) + for chunk in model: + for module in chunk.modules(): + if not isinstance(module, LoRA): + continue + a_keys = module._expected_weight_keys("lora_A") + b_keys = module._expected_weight_keys("lora_B") + for a_key, b_key in zip(a_keys, b_keys, strict=True): + adapter[a_key] = ( + torch.randn( + slot_rank, + module.in_features, + dtype=module.A_T.dtype, + device=module.A_T.device, + generator=generator, + ) + * 0.01 + ) + adapter[b_key] = ( + torch.randn( + module.out_features, + slot_rank, + dtype=module.B_T.dtype, + device=module.B_T.device, + generator=generator, + ) + * 0.01 + ) + if not adapter: + raise RuntimeError("adapter slot stress requested, but model has no LoRA sites") + return adapter + + +def _route_adapter_slots( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + adapter_slots: int, + mode: str, +) -> list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + if adapter_slots == 0: + return list(requests) + if mode not in {"family", "round_robin", "single", "skewed_random"}: + raise ValueError( + "adapter_slot_mode must be one of: family, round_robin, single, " + "skewed_random" + ) + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + top_k=request.top_k, + logits=request.logits, + hidden_states=request.hidden_states, + checkpoint=f"S{_adapter_slot_index(index, request, adapter_slots, mode)}", + ) + for index, request in enumerate(requests) + ] + + +def _adapter_slot_index( + index: int, + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + adapter_slots: int, + mode: str, +) -> int: + if mode == "single": + 
return 0 + if mode == "round_robin": + return index % adapter_slots + if mode == "skewed_random": + bucket = (index * 1103515245 + 12345) & 0x7FFFFFFF + skew = bucket % 100 + if skew < 50: + return 0 + if skew < 75: + return min(1, adapter_slots - 1) + if skew < 90: + return min(2, adapter_slots - 1) + return min(3 + (bucket % max(1, adapter_slots - 3)), adapter_slots - 1) + first_token = ( + int(request.input_tokens[0].item()) if request.input_tokens.numel() else 0 + ) + return (first_token // 10_000_019) % adapter_slots + + +def _with_outputs( + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + *, + target_tokens: torch.Tensor | None = None, + top_k: int | None = None, + logits: bool = False, + hidden_states: bool = False, +) -> ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None +]: + return ForwardInput( + input_tokens=request.input_tokens, + target_tokens=target_tokens, + top_k=top_k, + logits=logits, + hidden_states=hidden_states, + checkpoint=request.checkpoint, + lora=request.lora, + ) + + +def _workload_sequences( + *, + workload: str, + seq_len: int, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, + tree_depth: int, + tree_seed: int, + tree_duplicate_factor: int, +) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + if workload in {"austin_198k", "austin_5k_16x100"}: + return _regular_tree_sequences( + prefix_families=30, + prefix_len=5000, + mid_prefixes_per_family=1, + mid_prefix_len=0, + branches_per_prefix=16, + completion_len=100, + ) + if workload == "austin_varied": + return _austin_varied_sequences() + if workload == "regular": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, 
+ completion_len=completion_len, + ) + if workload == "single": + tokens = torch.arange(seq_len, dtype=torch.long) % 32_000 + 100 + return (tokens,), (0,), "single" + if workload == "long_root": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=1, + mid_prefix_len=0, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "long_mid": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(2, mid_prefixes_per_family), + mid_prefix_len=max(1, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "many_tiny_leaves": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(1, mid_prefixes_per_family), + mid_prefix_len=max(0, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=max(1, completion_len), + ) + if workload == "uneven": + return _uneven_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(2, mid_prefixes_per_family), + mid_prefix_len=max(1, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "duplicates": + sequences, shared, shape = _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(2, mid_prefixes_per_family), + mid_prefix_len=max(1, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + factor = max(1, tree_duplicate_factor) + return ( + tuple(sequence for sequence in sequences for _ in range(factor)), + tuple(length for length in shared for _ in range(factor)), + f"{shape}:duplicates={factor}", + ) + if workload == "random": + return _random_tree_sequences( + prefix_families=prefix_families, + 
prefix_len=prefix_len, + branches_per_prefix=max(2, min(branches_per_prefix, 4)), + completion_len=completion_len, + tree_depth=max(1, tree_depth), + seed=tree_seed, + ) + raise ValueError( + "workload must be one of: regular, single, long_root, long_mid, " + "many_tiny_leaves, uneven, duplicates, random, austin_198k, austin_varied" + ) + + +def _regular_tree_sequences( + *, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, +) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + nested = mid_prefixes_per_family > 1 and mid_prefix_len > 0 + sequences: list[torch.Tensor] = [] + shared_lengths: list[int] = [] + for family in range(prefix_families): + family_base = family * 10_000_019 + root = _tokens(family_base, prefix_len) + mid_count = mid_prefixes_per_family if nested else 1 + for mid in range(mid_count): + mid_prefix = ( + _tokens(family_base + 1_000_003 + mid * 100_003, mid_prefix_len) + if nested + else torch.empty(0, dtype=torch.long) + ) + shared = torch.cat((root, mid_prefix)) + for branch in range(branches_per_prefix): + sequences.append( + torch.cat( + ( + shared, + _tokens( + family_base + mid * 100_003 + branch * 1009 + 17, + completion_len, + ), + ) + ) + ) + shared_lengths.append(int(shared.numel())) + shape = ( + f"families={prefix_families}:mid={mid_prefixes_per_family}:" + f"branches={branches_per_prefix}:nested={int(nested)}" + ) + return tuple(sequences), tuple(shared_lengths), shape + + +def _austin_varied_sequences() -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + sequences: list[torch.Tensor] = [] + shared_lengths: list[int] = [] + for family in range(30): + family_base = family * 10_000_019 + prefix_len = 4500 + ((family * 137) % 1001) + root = _tokens(family_base, prefix_len) + branch_count = 10 + ((family * 7) % 13) + for branch in range(branch_count): + completion_len = 32 + ((family * 19 + branch * 23) % 145) + 
sequences.append( + torch.cat( + ( + root, + _tokens( + family_base + branch * 1009 + 17, + completion_len, + ), + ) + ) + ) + shared_lengths.append(int(root.numel())) + return tuple(sequences), tuple(shared_lengths), "austin_varied" + + +def _uneven_tree_sequences( + *, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, +) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + sequences: list[torch.Tensor] = [] + shared_lengths: list[int] = [] + for family in range(prefix_families): + family_base = family * 10_000_019 + root_len = max(1, prefix_len // (family + 1)) + root = _tokens(family_base, root_len) + for mid in range(mid_prefixes_per_family): + mid_len = max(1, mid_prefix_len // (mid + 1)) + mid_prefix = _tokens(family_base + 1_000_003 + mid * 100_003, mid_len) + branch_count = max(1, branches_per_prefix - mid) + for branch in range(branch_count): + leaf_len = max(1, completion_len * (branch + 1) // branch_count) + shared = torch.cat((root, mid_prefix)) + sequences.append( + torch.cat( + ( + shared, + _tokens( + family_base + mid * 100_003 + branch * 1009 + 17, + leaf_len, + ), + ) + ) + ) + shared_lengths.append(int(shared.numel())) + return tuple(sequences), tuple(shared_lengths), "uneven" + + +def _random_tree_sequences( + *, + prefix_families: int, + prefix_len: int, + branches_per_prefix: int, + completion_len: int, + tree_depth: int, + seed: int, +) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + generator = torch.Generator().manual_seed(seed) + next_offset = 1 + sequences: list[torch.Tensor] = [] + shared_lengths: list[int] = [] + + def randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=generator).item()) + + def segment(length: int) -> torch.Tensor: + nonlocal next_offset + out = _tokens(next_offset, max(1, length)) + next_offset += max(1, length) + 10_000 + return out + + def 
length_for_depth(depth: int) -> int: + if depth == 0: + return max(1, prefix_len) + choices = (1, 8, 64, max(1, completion_len), max(1, prefix_len // 2)) + return choices[randint(0, len(choices) - 1)] + + def walk(prefix: torch.Tensor, depth: int) -> None: + shared = torch.cat((prefix, segment(length_for_depth(depth)))) + if depth + 1 >= tree_depth: + leaf_count = randint(2, branches_per_prefix) + for _ in range(leaf_count): + leaf = segment(randint(1, max(1, completion_len))) + sequences.append(torch.cat((shared, leaf))) + shared_lengths.append(int(shared.numel())) + return + for _ in range(randint(2, branches_per_prefix)): + walk(shared, depth + 1) + + for _ in range(prefix_families): + walk(torch.empty(0, dtype=torch.long), 0) + return tuple(sequences), tuple(shared_lengths), f"random:depth={tree_depth}" + + +def _packed_request_stats( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + items: Sequence[object], + batch: object, + *, + request_metadata: dict[str, int | str], +) -> dict[str, int | str]: + from art.megatron.shared_prefix_tree import max_shared_prefix_tree_depth + + trainable_mask = torch.zeros(int(batch.tokens.numel()), dtype=torch.bool) + trainable_tokens = 0 + for item, positions in zip(items, batch.positions_by_item, strict=True): + labels = getattr(item, "labels", None) + if labels is None: + continue + mask = labels != -100 + row_mask = mask.reshape(int(mask.shape[0]), -1).any(dim=1) + trainable_tokens += int(mask.sum().item()) + trainable_mask[positions.reshape(-1).cpu()] |= row_mask.cpu() + group_ids = batch.group_ids + parent_ids = batch.parent_ids + return { + **request_metadata, + "request_count": len(requests), + "packed_tokens": int(batch.tokens.numel()), + "logical_tokens": sum( + int(request.input_tokens.numel()) for request in requests + ), + "trainable_tokens": trainable_tokens, + "packed_trainable_tokens": int(trainable_mask.sum().item()), + 
"packed_group_count": int(group_ids.max().item()) + if int(group_ids.numel()) + else 0, + "nested_prefix_depth": max_shared_prefix_tree_depth( + group_ids=group_ids, + parent_ids=parent_ids, + ), + } + + +def _gather_planner_metadata(prepared: object) -> dict[str, object]: + local = _local_planner_metadata(prepared) + gathered: list[dict[str, object] | None] = [None] * dist.get_world_size() + dist.all_gather_object(gathered, local) + if dist.get_rank() != 0: + return {} + ranks = [metrics or {} for metrics in gathered] + gdn_tokens = [int(metrics.get("gdn_tokens", 0)) for metrics in ranks] + attention_tokens = [int(metrics.get("attention_tokens", 0)) for metrics in ranks] + keys = ( + "tree_local_bucket_count", + "tree_chain_bucket_count", + "tree_local_segment_count", + "tree_chain_segment_count", + "tree_local_real_tokens", + "tree_chain_real_tokens", + "tree_state_transfer_count", + "tree_state_transfer_rows", + "tree_max_padding_ratio", + ) + merged: dict[str, object] = { + "planner_rank_gdn_tokens": gdn_tokens, + "planner_rank_attention_tokens": attention_tokens, + "planner_gdn_token_imbalance": max(gdn_tokens, default=0) + - min(gdn_tokens, default=0), + } + for key in keys: + values = [metrics[key] for metrics in ranks if key in metrics] + if not values: + continue + if key.endswith("_ratio"): + merged[f"planner_{key}_max"] = round( + max(float(value) for value in values), 3 + ) + else: + merged[f"planner_{key}_sum"] = int(sum(int(value) for value in values)) + merged[f"planner_{key}_max"] = int(max(int(value) for value in values)) + rank0 = ranks[0] if ranks else {} + for key in ("tree_depth_count", "tree_family_count", "tree_completion_count"): + if key in rank0: + merged[f"planner_{key}"] = rank0[key] + return merged + + +def _local_planner_metadata(prepared: object) -> dict[str, object]: + plan = getattr( + getattr(prepared, "attention_state", None), "gdn_execution_plan", None + ) + if plan is None: + return {} + local_buckets = tuple( + bucket + for 
depth in getattr(plan, "tree_segment_buckets_by_depth", ()) + for bucket in depth + ) + chain_buckets = tuple( + bucket + for depth in getattr(plan, "tree_chain_buckets_by_depth", ()) + for bucket in depth + ) + all_buckets = (*local_buckets, *chain_buckets) + padding_ratios = [ + bucket.length * bucket.segment_count / max(1, bucket.real_token_count) + for bucket in all_buckets + ] + transfers_by_depth = getattr(plan, "tree_state_transfers_by_depth", ()) + return { + "attention_tokens": int(getattr(plan, "attention_token_count", 0)), + "gdn_tokens": int(getattr(plan, "gdn_token_count", 0)), + "tree_depth_count": len(getattr(plan, "tree_segment_buckets_by_depth", ())), + "tree_family_count": int(getattr(plan, "family_count", 0)), + "tree_completion_count": int(getattr(plan, "completion_count", 0)), + "tree_local_bucket_count": len(local_buckets), + "tree_chain_bucket_count": len(chain_buckets), + "tree_local_segment_count": sum( + bucket.segment_count for bucket in local_buckets + ), + "tree_chain_segment_count": sum( + bucket.segment_count for bucket in chain_buckets + ), + "tree_local_real_tokens": sum( + bucket.real_token_count for bucket in local_buckets + ), + "tree_chain_real_tokens": sum( + bucket.real_token_count for bucket in chain_buckets + ), + "tree_state_transfer_count": sum( + len(transfers) for transfers in transfers_by_depth + ), + "tree_state_transfer_rows": sum( + len(transfer.family_indices) + for transfers in transfers_by_depth + for transfer in transfers + ), + "tree_max_padding_ratio": max(padding_ratios, default=1.0), + } + + +def _tokens(offset: int, length: int) -> torch.Tensor: + return (torch.arange(length, dtype=torch.long) + offset) % 32_000 + 100 + + +def _int_values(value: str) -> list[int]: + values = [int(part) for part in value.split(",") if part.strip()] + if not values or any(item < 1 for item in values): + raise ValueError("top_k_values must contain positive integers") + return values + + +def _labels(tokens: torch.Tensor, *, 
target_count: int) -> torch.Tensor: + labels = torch.stack( + [((tokens * 7 + 3 + index) % 32_000) for index in range(target_count)], + dim=1, + ) + if target_count > 1: + labels[::17, -1] = -100 + return labels + return labels[:, 0] + + +class _CudaMemoryTracker: + def __init__(self, *, device_index: int, sample_interval_s: float) -> None: + self.device_index = device_index + self.sample_interval_s = sample_interval_s + self.process_peak_bytes = 0 + self.allocated_peak_bytes = 0 + self.reserved_peak_bytes = 0 + self._stop = threading.Event() + self._thread: threading.Thread | None = None + + def start(self) -> None: + if not torch.cuda.is_available(): + return + torch.cuda.reset_peak_memory_stats() + self._sample() + if self.sample_interval_s <= 0: + return + self._thread = threading.Thread(target=self._poll, daemon=True) + self._thread.start() + + def stop(self) -> None: + if not torch.cuda.is_available(): + return + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=1.0) + torch.cuda.synchronize() + self._sample() + self.allocated_peak_bytes = max( + self.allocated_peak_bytes, + int(torch.cuda.max_memory_allocated()), + ) + self.reserved_peak_bytes = max( + self.reserved_peak_bytes, + int(torch.cuda.max_memory_reserved()), + ) + + def _poll(self) -> None: + while not self._stop.wait(self.sample_interval_s): + self._sample() + + def _sample(self) -> None: + self.process_peak_bytes = max( + self.process_peak_bytes, + _current_process_gpu_memory_bytes(self.device_index), + ) + self.allocated_peak_bytes = max( + self.allocated_peak_bytes, + int(torch.cuda.memory_allocated()) if torch.cuda.is_available() else 0, + ) + self.reserved_peak_bytes = max( + self.reserved_peak_bytes, + int(torch.cuda.memory_reserved()) if torch.cuda.is_available() else 0, + ) + + +def _current_process_gpu_memory_bytes(device_index: int) -> int: + try: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) + pid = os.getpid() 
+ processes = list(pynvml.nvmlDeviceGetComputeRunningProcesses(handle)) + with suppress(Exception): + processes.extend(pynvml.nvmlDeviceGetGraphicsRunningProcesses(handle)) + for process in processes: + if int(process.pid) == pid: + return int(process.usedGpuMemory) + except Exception: + return 0 + return 0 + + +def _distributed_memory_metadata(tracker: _CudaMemoryTracker) -> dict[str, float]: + values = torch.tensor( + [ + tracker.allocated_peak_bytes, + tracker.reserved_peak_bytes, + tracker.process_peak_bytes, + ], + device="cuda", + dtype=torch.float64, + ) + dist.all_reduce(values, op=dist.ReduceOp.MAX) + return { + "peak_memory_allocated_gb": round(float(values[0].item()) / 1024**3, 3), + "peak_memory_reserved_gb": round(float(values[1].item()) / 1024**3, 3), + "peak_memory_process_gb": round(float(values[2].item()) / 1024**3, 3), + "peak_memory_gb": round(float(values[0].item()) / 1024**3, 3), + } + + +def _mean_abs_pct(reference: torch.Tensor, candidate: torch.Tensor) -> float: + reference_fp32 = reference.detach().float() + candidate_fp32 = candidate.detach().float() + return float( + (candidate_fp32 - reference_fp32).abs().mean().item() + / (reference_fp32.abs().mean().item() + 1e-18) + ) + + +def _model_metadata(runtime: object, model_name: str, *, layers: int) -> dict[str, Any]: + from art.megatron.lora import LoRA + + provider = getattr(runtime, "provider") + model = _language_model(getattr(runtime, "model")[0]) + config = getattr(model, "config", None) + total_params = sum( + int(param.numel()) for chunk in runtime.model for param in chunk.parameters() + ) + trainable_params = sum( + int(param.numel()) + for chunk in runtime.model + for param in chunk.parameters() + if param.requires_grad + ) + lora_sites = sum( + 1 + for chunk in runtime.model + for module in chunk.modules() + if isinstance(module, LoRA) + ) + local = torch.tensor( + [total_params, trainable_params, lora_sites], + device="cuda", + dtype=torch.float64, + ) + dist.all_reduce(local, 
op=dist.ReduceOp.MAX) + return { + "model": model_name, + "layers_arg": layers, + "provider_num_layers": getattr(provider, "num_layers", None), + "config_num_layers": getattr(config, "num_layers", None), + "rank_local_param_count": int(local[0].item()), + "rank_local_trainable_param_count": int(local[1].item()), + "rank_local_lora_site_count": int(local[2].item()), + } + + +def _bench( + fn: Callable[[], object], + *, + warmup: int, + repeat: int, + after: Callable[[], object] | None = None, +) -> float: + for _ in range(warmup): + fn() + if after is not None: + after() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + stop = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(repeat): + fn() + if after is not None: + after() + stop.record() + torch.cuda.synchronize() + elapsed = torch.tensor(start.elapsed_time(stop) / repeat, device="cuda") + dist.all_reduce(elapsed, op=dist.ReduceOp.MAX) + return round(float(elapsed.item()), 3) + + +def _builtin( + rank: TrainerRank, + prepared: object, + labels: torch.Tensor | None, +) -> torch.Tensor: + from art.megatron.train import _placeholder_attention_mask + + return rank.runtime.model[0]( + input_ids=prepared.tokens, + position_ids=prepared.position_ids, + attention_mask=_placeholder_attention_mask(rank.device), + labels=labels, + packed_seq_params=prepared.packed_seq_params, + **rank._handler().get_forward_kwargs( + rank.runtime.model[0], + attention_bias=prepared.attention_state, + ), + ) + + +def _full_logits(rank: TrainerRank, prepared: object) -> torch.Tensor: + logits = rank._gather_tensor_parallel_logits(_builtin(rank, prepared, None)) + return _batch_seq_logits(logits, seq_len=int(prepared.tokens.shape[1])) + + +def _target_builtin_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + return _builtin(rank, prepared, _packed_labels(items, prepared)).float().sum() + + +def _target_builtin_masked_loss( + rank: TrainerRank, + items: object, + 
prepared: object, +) -> torch.Tensor: + labels = _packed_labels(items, prepared) + per_token_loss = _builtin(rank, prepared, labels).float().reshape(-1) + valid = labels.reshape(-1) != -100 + return per_token_loss[valid].sum() + per_token_loss.sum() * 0.0 + + +def _target_hidden_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + outputs = rank._project_head(items, prepared, hidden) + losses = [ + -target_logprobs.sum() + for target_logprobs in outputs.target_logprobs + if target_logprobs is not None + ] + if not losses: + raise RuntimeError("target logprobs were not produced") + return torch.stack(losses).sum() + + +def _target_trainer_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + outputs = rank._forward_packed(items, prepared) + losses = [ + -output.target_logprobs.sum() + for output in outputs + if output.target_logprobs is not None + ] + if not losses: + raise RuntimeError("target logprobs were not produced") + return torch.stack(losses).sum() + + +def _target_requests_loss( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> torch.Tensor: + outputs = rank.forward(requests) + losses = [ + -output.target_logprobs.sum() + for output in outputs + if output.target_logprobs is not None + ] + if not losses: + raise RuntimeError("target logprobs were not produced") + return torch.stack(losses).sum() + + +def _trainer_topk_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + outputs = rank._forward_packed(items, prepared) + losses = [ + -output.top_k.logprobs.sum() for output in outputs if output.top_k is not None + ] + if not losses: + raise RuntimeError("top_k logprobs were not produced") + return torch.stack(losses).sum() + + +def _topk_requests_loss( + rank: TrainerRank, + requests: 
Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> torch.Tensor: + outputs = rank.forward(requests) + losses = [ + -output.top_k.logprobs.sum() for output in outputs if output.top_k is not None + ] + if not losses: + raise RuntimeError("top_k logprobs were not produced") + return torch.stack(losses).sum() + + +def _training_step( + rank: TrainerRank, + loss_fn: Callable[[], torch.Tensor], + *, + params: AdamParams, + offload_manager: object | None, +) -> dict[str, float]: + if offload_manager is None: + return _training_step_body(rank, loss_fn, params=params) + with offload_manager.job(): # type: ignore[attr-defined] + return _training_step_body(rank, loss_fn, params=params) + + +def _training_step_body( + rank: TrainerRank, + loss_fn: Callable[[], torch.Tensor], + *, + params: AdamParams, +) -> dict[str, float]: + rank.zero_grad() + loss = loss_fn() + loss.backward() + return rank.optim_step(params=params, scale_grads=1.0) + + +def _make_offload_manager(runtime: object) -> object: + from art.megatron.training.streaming_weight_offload import ( + StreamingWeightOffloadConfig, + ) + from art.megatron.training.weight_offload import WeightOffloadManager + + manager = WeightOffloadManager.from_config( + model=getattr(runtime, "model"), + rank=dist.get_rank(), + compile_enabled=bool(getattr(runtime, "transformer_layers_compiled", False)), + offload_between_jobs=True, + streaming_config=StreamingWeightOffloadConfig(enabled=False), + ) + manager.install() + manager.after_job() + return manager + + +def _target_correctness_metrics( + rank: TrainerRank, + items: object, + prepared: object, +) -> dict[str, float]: + for chunk in rank.runtime.model: + chunk.eval() + with torch.no_grad(): + labels = _packed_labels(items, prepared) + native_outputs = rank._forward_native_target_logprobs(items, prepared, labels) + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + head_outputs = 
rank._project_head(items, prepared, hidden) + abs_diff_sum = torch.tensor(0.0, device=rank.device) + reference_abs_sum = torch.tensor(0.0, device=rank.device) + value_count = torch.tensor(0.0, device=rank.device) + max_abs_diff = torch.tensor(0.0, device=rank.device) + for native, candidate in zip( + native_outputs, + head_outputs.target_logprobs, + strict=True, + ): + if native.target_logprobs is None or candidate is None: + continue + diff = (candidate.float() - native.target_logprobs.float()).abs() + if int(diff.numel()) == 0: + continue + abs_diff_sum += diff.sum() + reference_abs_sum += native.target_logprobs.float().abs().sum() + value_count += float(diff.numel()) + max_abs_diff = torch.maximum(max_abs_diff, diff.max()) + sums = torch.stack((abs_diff_sum, reference_abs_sum, value_count)) + dist.all_reduce(sums, op=dist.ReduceOp.SUM) + dist.all_reduce(max_abs_diff, op=dist.ReduceOp.MAX) + mean_abs_pct = float((sums[0] / torch.clamp(sums[1], min=1e-18)).item()) + max_abs = float(max_abs_diff.item()) + return { + "target_hidden_vs_native_mean_abs_pct": mean_abs_pct, + "target_hidden_vs_native_max_abs_diff": max_abs, + "target_hidden_vs_native_value_count": float(sums[2].item()), + } + + +def _adapter_sanity_metrics( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + params: AdamParams, + adapter_slots: int, +) -> dict[str, float]: + target_request = next( + (request for request in requests if request.target_tokens is not None), + None, + ) + if target_request is None: + return {"adapter_sanity_skipped": 1.0} + base_request = ForwardInput( + input_tokens=target_request.input_tokens, + target_tokens=target_request.target_tokens, + checkpoint=None, + ) + slot_request = ForwardInput( + input_tokens=target_request.input_tokens, + target_tokens=target_request.target_tokens, + checkpoint="S0", + ) + for chunk in rank.runtime.model: + chunk.eval() + with torch.no_grad(): 
+ base_output = rank.forward([base_request])[0] + slot_output = rank.forward([slot_request])[0] + if base_output.target_logprobs is None or slot_output.target_logprobs is None: + raise RuntimeError("adapter sanity target outputs were not produced") + output_diff = _mean_abs_pct( + base_output.target_logprobs, + slot_output.target_logprobs, + ) + output_max = float( + (slot_output.target_logprobs.float() - base_output.target_logprobs.float()) + .abs() + .max() + .item() + ) + + slot_params = rank._checkpoint_slot_params("S0") + other_params = rank._checkpoint_slot_params("S1") if adapter_slots > 1 else [] + before = [param.detach().clone() for param in slot_params] + other_before = [param.detach().clone() for param in other_params] + for chunk in rank.runtime.model: + chunk.train() + rank.zero_grad() + loss = _target_requests_loss(rank, [slot_request]) + loss.backward() + grad_sq = torch.tensor(0.0, device=rank.device) + for param in slot_params: + if param.grad is not None: + grad_sq = grad_sq + param.grad.detach().float().square().sum() + grad_norm = torch.sqrt(grad_sq) + rank.optim_step(params=params, checkpoints=["S0"]) + slot_delta = sum( + float((param.detach().float() - old.float()).abs().sum().item()) + for param, old in zip(slot_params, before, strict=True) + ) + other_delta = sum( + float((param.detach().float() - old.float()).abs().sum().item()) + for param, old in zip(other_params, other_before, strict=True) + ) + values = torch.tensor( + [output_diff, output_max, float(grad_norm.item()), slot_delta, other_delta], + device=rank.device, + ) + dist.all_reduce(values, op=dist.ReduceOp.MAX) + return { + "adapter_sanity_output_mean_abs_pct": float(values[0].item()), + "adapter_sanity_output_max_abs_diff": float(values[1].item()), + "adapter_sanity_grad_norm": float(values[2].item()), + "adapter_sanity_stepped_slot_delta": float(values[3].item()), + "adapter_sanity_unselected_slot_delta": float(values[4].item()), + } + + +def _runtime_output_shape(runtime: 
object) -> tuple[int, int, int]: + provider = getattr(runtime, "provider") + model = _language_model(getattr(runtime, "model")[0]) + hidden_size = int( + getattr(provider, "hidden_size", None) + or getattr(getattr(model, "config", None), "hidden_size", 0) + ) + vocab_size = int( + getattr(getattr(model, "config", None), "padded_vocab_size", None) + or getattr(model, "vocab_size", 0) + ) + dtype_size = next(getattr(runtime, "model")[0].parameters()).element_size() + if hidden_size <= 0 or vocab_size <= 0: + raise RuntimeError( + f"could not infer output shape: hidden_size={hidden_size}, " + f"vocab_size={vocab_size}" + ) + return hidden_size, vocab_size, dtype_size + + +def _request_output_gb( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + hidden_size: int, + vocab_size: int, + dtype_size: int, +) -> float: + return ( + sum( + _request_output_bytes( + request, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + for request in requests + ) + / 1024**3 + ) + + +def _request_output_bytes( + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + *, + hidden_size: int, + vocab_size: int, + dtype_size: int, +) -> int: + seq_len = int(request.input_tokens.numel()) + bytes_total = 0 + if request.target_tokens is not None: + bytes_total += int(request.target_tokens.numel()) * 4 + if request.top_k is not None: + bytes_total += seq_len * int(request.top_k) * (4 + 8) + if request.logits: + bytes_total += seq_len * vocab_size * dtype_size + if request.hidden_states: + bytes_total += seq_len * hidden_size * dtype_size + return bytes_total + + +def _logits_requests( + requests: Sequence[ForwardInput[torch.Tensor, None, None, None]], +) -> list[ForwardInput[None, None, torch.Tensor, None]]: + return [ + ForwardInput(input_tokens=request.input_tokens, logits=True) + for request in requests + ] + + +def _rate_units( + 
requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + stats: dict[str, int | str], + *, + hidden_size: int, + vocab_size: int, + dtype_size: int, +) -> dict[str, int]: + return { + "packed_tokens": int(stats.get("packed_tokens", 0)), + "logical_tokens": int(stats.get("logical_tokens", 0)), + "target_values": _target_value_count(requests), + "output_bytes": sum( + _request_output_bytes( + request, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + for request in requests + ), + } + + +def _target_value_count( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> int: + count = 0 + for request in requests: + if request.target_tokens is not None: + count += int((request.target_tokens != -100).sum().item()) + return count + + +def _rate_metrics( + results: dict[str, float], + units_by_name: dict[str, dict[str, int]], +) -> dict[str, float]: + suffixes = { + "packed_tokens": "packed_tok_s", + "logical_tokens": "logical_tok_s", + "target_values": "target_logprob_s", + "output_bytes": "output_gb_s", + } + metrics: dict[str, float] = {} + for key, ms in results.items(): + if ms <= 0: + continue + name = key.removesuffix("_ms") + units = units_by_name.get(name, {}) + for unit_key, suffix in suffixes.items(): + value = int(units.get(unit_key, 0)) + if value <= 0: + continue + scale = 1024**3 if unit_key == "output_bytes" else 1 + metrics[f"{name}_{suffix}"] = round(value * 1000.0 / ms / scale, 3) + return metrics + + +def _packed_labels(items: object, prepared: object) -> torch.Tensor: + labels = torch.full_like(prepared.tokens, -100) + for item, positions, source_positions in zip( + items, + prepared.positions_by_item, + prepared.source_positions_by_item, + strict=True, + ): + if item.labels is None: + continue + labels.reshape(-1)[positions.to(device=labels.device)] = item.labels.to( + 
device=labels.device + ).index_select(0, source_positions.to(device=labels.device)) + return labels + + +if __name__ == "__main__": + typer.run(main) diff --git a/dev/trainer_rank_review_perf.py b/dev/trainer_rank_review_perf.py new file mode 100644 index 000000000..2af0bf414 --- /dev/null +++ b/dev/trainer_rank_review_perf.py @@ -0,0 +1,707 @@ +from __future__ import annotations + +from collections.abc import Callable, Sequence +import json +from pathlib import Path +import time + +import torch +from torch.nn.attention.flex_attention import BlockMask +from torch.nn.attention.flex_attention import create_block_mask as torch_block_mask +import typer + +from art.megatron.context_parallel.block_mask import ( + build_block_mask_from_context, + prepare_block_mask_context, +) +from art.megatron.context_parallel.builder import build_shared_prefix_attention_spec +from art.megatron.context_parallel.runtime import ( + _RUNTIME_PLAN_CACHE, + get_or_build_runtime_plan, + make_runtime_key, +) +from art.megatron.context_parallel.types import ( + ContextParallelConfig, + FlexMaskSpec, + ParallelTopology, +) +from art.megatron.flex_attn.attention import FlexAttentionWrapper +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes +from art.megatron.shared_prefix_state import create_shared_prefix_state + + +def main( + workload: str = "austin_198k", + max_depth: int = 1, + cp_size: int = 4, + block_size: int = 128, + prefix_families: int = 4, + prefix_len: int = 1024, + mid_prefixes_per_family: int = 1, + mid_prefix_len: int = 0, + branches_per_prefix: int = 8, + completion_len: int = 128, + warmup: int = 3, + repeat: int = 10, + shape_variants: int = 4, + validate_torch: bool = True, + run_flex: bool = True, + flex_token_cap: int = 8192, + flex_heads: int = 2, + flex_head_dim: int = 64, + flex_mask_variants: str = "current,flat_pair,token_group,local_or_flat_pair", + output_jsonl: Path = Path(".local/trainer_rank_review/block_mask_flex.jsonl"), +) -> 
None: + if warmup < 0 or repeat < 1: + raise ValueError("warmup must be >= 0 and repeat must be >= 1") + output_jsonl.parent.mkdir(parents=True, exist_ok=True) + + pack = _pack_workload( + workload=workload, + max_depth=max_depth, + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + spec = build_shared_prefix_attention_spec( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + config = ContextParallelConfig(block_size=block_size) + topology = ParallelTopology(cp=cp_size) + base = { + "workload": workload, + "max_depth": max_depth, + "cp_size": cp_size, + "block_size": block_size, + "packed_tokens": int(pack.tokens.numel()), + "logical_tokens": _logical_tokens(pack), + "warmup": warmup, + "repeat": repeat, + } + + plan, plan_ms = _bench_cpu( + lambda: _build_cp_plan(pack, spec, topology, config), + warmup=warmup, + repeat=repeat, + before_each=_RUNTIME_PLAN_CACHE.clear, + ) + _write( + output_jsonl, + { + **base, + "case": "cp_planning_cold", + "ms": plan_ms, + **_plan_stats(plan), + }, + ) + + cached_plan, cached_plan_ms = _bench_cpu( + lambda: _build_cp_plan(pack, spec, topology, config), + warmup=warmup, + repeat=repeat, + ) + _write( + output_jsonl, + { + **base, + "case": "cp_planning_cached", + "ms": cached_plan_ms, + **_plan_stats(cached_plan), + }, + ) + + stage_masks, mask_ms = _bench_cpu( + lambda: _build_stage_masks(pack, plan, config), + warmup=warmup, + repeat=repeat, + ) + masks = tuple(mask for mask, _ in stage_masks) + if validate_torch: + for mask, slices in stage_masks: + _assert_matches_torch_block_mask(mask, slices=slices) + _write( + output_jsonl, + { + **base, + "case": "block_mask_build", + "ms": mask_ms, + **_mask_stats(masks), + }, + ) + + if run_flex: + for record in _flex_records( + pack, + warmup=warmup, + repeat=repeat, + token_cap=flex_token_cap, + 
heads=flex_heads, + head_dim=flex_head_dim, + variants=_csv_values(flex_mask_variants), + ): + _write(output_jsonl, {**base, **record}) + + for variant in range(shape_variants): + variant_pack = _pack_workload( + workload="regular", + max_depth=max_depth, + prefix_families=prefix_families, + prefix_len=prefix_len + variant * 17, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len + variant * 3, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len + variant * 11, + ) + variant_spec = build_shared_prefix_attention_spec( + group_ids=variant_pack.group_ids, + parent_ids=variant_pack.parent_ids, + ) + variant_plan, variant_plan_ms = _bench_cpu( + lambda pack=variant_pack, spec=variant_spec: _build_cp_plan( + pack, + spec, + topology, + config, + ), + warmup=0, + repeat=1, + before_each=_RUNTIME_PLAN_CACHE.clear, + ) + variant_stage_masks, variant_mask_ms = _bench_cpu( + lambda pack=variant_pack, plan=variant_plan: _build_stage_masks( + pack, + plan, + config, + ), + warmup=0, + repeat=1, + ) + variant_masks = tuple(mask for mask, _ in variant_stage_masks) + if validate_torch: + for mask, slices in variant_stage_masks: + _assert_matches_torch_block_mask(mask, slices=slices) + _write( + output_jsonl, + { + **base, + "case": "shape_variant", + "variant": variant, + "variant_packed_tokens": int(variant_pack.tokens.numel()), + "variant_logical_tokens": _logical_tokens(variant_pack), + "cp_planning_ms": variant_plan_ms, + "block_mask_build_ms": variant_mask_ms, + **_plan_stats(variant_plan), + **_mask_stats(variant_masks), + }, + ) + + print(f"wrote review perf records to {output_jsonl}", flush=True) + + +def _pack_workload( + *, + workload: str, + max_depth: int, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, +) -> SharedPrefixPack: + sequences = ( + _austin_sequences() + if workload == "austin_198k" + else 
_austin_varied_sequences() + if workload == "austin_varied" + else _regular_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + ) + return pack_shared_prefixes(sequences, max_depth=max_depth) + + +def _austin_sequences() -> tuple[torch.Tensor, ...]: + return tuple( + torch.cat( + ( + _tokens(family * 10_000_019, 5000), + _tokens(family * 10_000_019 + branch * 1009 + 17, 100), + ) + ) + for family in range(30) + for branch in range(16) + ) + + +def _austin_varied_sequences() -> tuple[torch.Tensor, ...]: + sequences: list[torch.Tensor] = [] + for family in range(30): + family_base = family * 10_000_019 + prefix_len = 4500 + ((family * 137) % 1001) + root = _tokens(family_base, prefix_len) + branch_count = 10 + ((family * 7) % 13) + for branch in range(branch_count): + completion_len = 32 + ((family * 19 + branch * 23) % 145) + sequences.append( + torch.cat( + ( + root, + _tokens( + family_base + branch * 1009 + 17, + completion_len, + ), + ) + ) + ) + return tuple(sequences) + + +def _regular_sequences( + *, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, +) -> tuple[torch.Tensor, ...]: + sequences = [] + for family in range(max(1, prefix_families)): + family_base = family * 10_000_019 + root = _tokens(family_base, max(1, prefix_len)) + for mid in range(max(1, mid_prefixes_per_family)): + mid_prefix = _tokens( + family_base + 1_000_003 + mid * 100_003, + max(0, mid_prefix_len), + ) + prefix = torch.cat((root, mid_prefix)) + for branch in range(max(1, branches_per_prefix)): + sequences.append( + torch.cat( + ( + prefix, + _tokens( + family_base + mid * 100_003 + branch * 1009 + 17, + max(1, completion_len), + ), + ) + ) + ) + return tuple(sequences) + + +def _tokens(offset: int, length: 
int) -> torch.Tensor: + return (torch.arange(length, dtype=torch.long) + offset) % 32_000 + 100 + + +def _build_cp_plan( + pack: SharedPrefixPack, + spec: object, + topology: ParallelTopology, + config: ContextParallelConfig, +) -> object: + return get_or_build_runtime_plan( + spec, + topology=topology, + config=config, + runtime_key=make_runtime_key(spec, topology=topology, config=config), + original_seq_len=int(pack.tokens.numel()), + ) + + +def _build_stage_masks( + pack: SharedPrefixPack, + plan: object, + config: ContextParallelConfig, +) -> tuple[tuple[BlockMask, tuple[object, ...]], ...]: + masks = [] + context = prepare_block_mask_context( + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + ) + for rank_plan in plan.rank_plans: + for stage in rank_plan.stage_plans: + if stage.mask_metadata is None: + continue + mask = build_block_mask_from_context( + FlexMaskSpec( + q_len=stage.q_len, + k_len=stage.k_len, + block_size=config.block_size, + slices=stage.slices, + exact_mask=stage.mask_metadata, + ), + context=context, + device=torch.device("cpu"), + validate=False, + ) + if mask is not None: + masks.append((mask, tuple(stage.slices))) + return tuple(masks) + + +def _flex_records( + pack: SharedPrefixPack, + *, + warmup: int, + repeat: int, + token_cap: int, + heads: int, + head_dim: int, + variants: Sequence[str], +) -> list[dict[str, object]]: + if not torch.cuda.is_available(): + return [{"case": "flex_attention_fwd_bwd", "skipped": "cuda_unavailable"}] + if int(pack.tokens.numel()) > int(token_cap): + return [ + { + "case": "flex_attention_fwd_bwd", + "skipped": "packed_tokens_exceed_flex_token_cap", + "flex_token_cap": int(token_cap), + } + ] + device = torch.device("cuda") + group_ids = pack.group_ids.to(device) + parent_ids = pack.parent_ids.to(device) + attention_state = create_shared_prefix_state( + group_ids, + parent_ids, + target_device=device, + ) + shape = (1, int(heads), int(pack.tokens.numel()), int(head_dim)) + records: 
list[dict[str, object]] = [] + block_masks = _flex_mask_variants( + attention_state.block_mask, + pack, + variants=variants, + device=device, + ) + for variant, block_mask in block_masks: + q = torch.randn(shape, device=device, dtype=torch.bfloat16, requires_grad=True) + k = torch.randn(shape, device=device, dtype=torch.bfloat16, requires_grad=True) + v = torch.randn(shape, device=device, dtype=torch.bfloat16, requires_grad=True) + wrapper = FlexAttentionWrapper() + + def step() -> None: + q.grad = None + k.grad = None + v.grad = None + out = wrapper( + q, + k, + v, + block_mask=block_mask, + scale=float(head_dim) ** -0.5, + enable_gqa=False, + ) + out.float().sum().backward() + + try: + torch.cuda.synchronize() + first_started = time.perf_counter() + step() + torch.cuda.synchronize() + first_call_ms = round((time.perf_counter() - first_started) * 1000.0, 3) + ms = _bench_cuda(step, warmup=warmup, repeat=repeat) + except Exception as exc: + torch.cuda.empty_cache() + records.append( + { + "case": "flex_attention_fwd_bwd", + "flex_mask_variant": variant, + "compile_error": type(exc).__name__, + "compile_error_message": str(exc).splitlines()[0][:500], + "flex_heads": heads, + "flex_head_dim": head_dim, + } + ) + continue + records.append( + { + "case": "flex_attention_fwd_bwd", + "flex_mask_variant": variant, + "first_call_ms": first_call_ms, + "ms": ms, + "packed_tok_s": round(int(pack.tokens.numel()) * 1000.0 / ms, 3), + "flex_heads": heads, + "flex_head_dim": head_dim, + "peak_memory_gb": round(torch.cuda.max_memory_allocated() / 1024**3, 3), + } + ) + return records + + +def _flex_mask_variants( + block_mask: BlockMask, + pack: SharedPrefixPack, + *, + variants: Sequence[str], + device: torch.device, +) -> tuple[tuple[str, BlockMask], ...]: + group_ids = pack.group_ids[0].to(device=device, dtype=torch.long) + can_attend = _group_can_attend(pack).to(device=device) + token_group_can_attend = can_attend.index_select(0, group_ids) + stride = int(can_attend.shape[1]) 
+ can_attend_flat = can_attend.reshape(-1) + out = [] + for variant in variants: + if variant == "current": + out.append((variant, block_mask)) + continue + if variant == "flat_pair": + + def mask_mod(batch_idx, head_idx, query_idx, kv_idx): + del batch_idx, head_idx + q_group = group_ids[query_idx] + k_group = group_ids[kv_idx] + return (query_idx >= kv_idx) & can_attend_flat[ + q_group * stride + k_group + ] + + elif variant == "token_group": + + def mask_mod(batch_idx, head_idx, query_idx, kv_idx): + del batch_idx, head_idx + k_group = group_ids[kv_idx] + return (query_idx >= kv_idx) & token_group_can_attend[ + query_idx, k_group + ] + + elif variant == "local_or_flat_pair": + + def mask_mod(batch_idx, head_idx, query_idx, kv_idx): + del batch_idx, head_idx + q_group = group_ids[query_idx] + k_group = group_ids[kv_idx] + allowed = (q_group == k_group) | can_attend_flat[ + q_group * stride + k_group + ] + return (query_idx >= kv_idx) & allowed + + else: + raise ValueError(f"unknown flex_mask_variant {variant!r}") + out.append((variant, _replace_block_mask_mod(block_mask, mask_mod))) + return tuple(out) + + +def _group_can_attend(pack: SharedPrefixPack) -> torch.Tensor: + group_ids = pack.group_ids[0].to(dtype=torch.long).cpu() + parent_ids = pack.parent_ids[0].to(dtype=torch.long).cpu() + max_group = int(group_ids.max().item()) if int(group_ids.numel()) else 0 + parents = [0 for _ in range(max_group + 1)] + for group, parent in zip(group_ids.tolist(), parent_ids.tolist(), strict=True): + if int(group) >= 0: + parents[int(group)] = max(0, int(parent)) + can_attend = torch.zeros((max_group + 1, max_group + 1), dtype=torch.bool) + for group in range(1, max_group + 1): + current = group + seen: set[int] = set() + while current > 0 and current not in seen: + seen.add(current) + can_attend[group, current] = True + parent = parents[current] + if parent == current: + break + current = parent + return can_attend + + +def _replace_block_mask_mod(block_mask: BlockMask, 
mask_mod: object) -> BlockMask: + return BlockMask( + seq_lengths=block_mask.seq_lengths, + kv_num_blocks=block_mask.kv_num_blocks, + kv_indices=block_mask.kv_indices, + full_kv_num_blocks=block_mask.full_kv_num_blocks, + full_kv_indices=block_mask.full_kv_indices, + q_num_blocks=block_mask.q_num_blocks, + q_indices=block_mask.q_indices, + full_q_num_blocks=block_mask.full_q_num_blocks, + full_q_indices=block_mask.full_q_indices, + BLOCK_SIZE=block_mask.BLOCK_SIZE, + mask_mod=mask_mod, + ) + + +def _bench_cpu( + fn: Callable[[], object], + *, + warmup: int, + repeat: int, + before_each: Callable[[], object] | None = None, +) -> tuple[object, float]: + result = None + for _ in range(warmup): + if before_each is not None: + before_each() + result = fn() + elapsed = [] + for _ in range(repeat): + if before_each is not None: + before_each() + start = time.perf_counter() + result = fn() + elapsed.append((time.perf_counter() - start) * 1000.0) + assert result is not None + return result, round(sum(elapsed) / len(elapsed), 3) + + +def _bench_cuda(fn: Callable[[], object], *, warmup: int, repeat: int) -> float: + torch.cuda.reset_peak_memory_stats() + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + stop = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(repeat): + fn() + stop.record() + torch.cuda.synchronize() + return round(float(start.elapsed_time(stop)) / repeat, 3) + + +def _plan_stats(plan: object) -> dict[str, int]: + stage_count = 0 + remote_stage_count = 0 + mask_stage_count = 0 + for rank_plan in plan.rank_plans: + for stage in rank_plan.stage_plans: + stage_count += 1 + remote_stage_count += int(not stage.is_local_stage) + mask_stage_count += int(stage.mask_metadata is not None) + return { + "rank_count": len(plan.rank_plans), + "stage_count": stage_count, + "remote_stage_count": remote_stage_count, + "mask_stage_count": mask_stage_count, + } + + +def _mask_stats(masks: 
Sequence[BlockMask]) -> dict[str, int]: + return { + "mask_count": len(masks), + "partial_kv_blocks": sum(_block_count(mask, "kv_num_blocks") for mask in masks), + "full_kv_blocks": sum( + _block_count(mask, "full_kv_num_blocks") for mask in masks + ), + "partial_q_blocks": sum(_block_count(mask, "q_num_blocks") for mask in masks), + "full_q_blocks": sum(_block_count(mask, "full_q_num_blocks") for mask in masks), + } + + +def _block_count(block_mask: BlockMask, name: str) -> int: + counts = getattr(block_mask, name) + return 0 if counts is None else int(counts.sum().item()) + + +def _assert_matches_torch_block_mask( + block_mask: BlockMask, + *, + slices: Sequence[object] = (), +) -> None: + q_len, k_len = block_mask.seq_lengths + reference = torch_block_mask( + _slice_mask_mod(block_mask.mask_mod, slices), + B=int(block_mask.kv_num_blocks.shape[0]), + H=1, + Q_LEN=q_len, + KV_LEN=k_len, + device="cpu", + BLOCK_SIZE=block_mask.BLOCK_SIZE, + ) + for counts_name, indices_name in ( + ("kv_num_blocks", "kv_indices"), + ("full_kv_num_blocks", "full_kv_indices"), + ("q_num_blocks", "q_indices"), + ("full_q_num_blocks", "full_q_indices"), + ): + actual = _block_entries(block_mask, counts_name, indices_name) + expected = _block_entries(reference, counts_name, indices_name) + if actual != expected: + raise AssertionError(f"{counts_name}/{indices_name} mismatch") + + +def _slice_mask_mod(mask_mod: object, slices: Sequence[object]) -> object: + if not slices: + return mask_mod + + def sliced_mask_mod( + batch_idx: torch.Tensor, + head_idx: torch.Tensor, + query_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + in_slice = (query_idx < 0) & (kv_idx < 0) + for slice_ in slices: + in_slice |= ( + (query_idx >= int(slice_.q_range.start)) + & (query_idx < int(slice_.q_range.end)) + & (kv_idx >= int(slice_.k_range.start)) + & (kv_idx < int(slice_.k_range.end)) + ) + return in_slice & mask_mod(batch_idx, head_idx, query_idx, kv_idx) + + return sliced_mask_mod + + +def 
_block_entries( + block_mask: BlockMask, + counts_name: str, + indices_name: str, +) -> set[tuple[int, int, int, int]]: + counts = getattr(block_mask, counts_name) + indices = getattr(block_mask, indices_name) + if counts is None or indices is None: + return set() + entries = set() + for batch_index in range(int(counts.shape[0])): + for head_index in range(int(counts.shape[1])): + for block_index in range(int(counts.shape[2])): + block_count = int(counts[batch_index, head_index, block_index]) + for other_block in indices[ + batch_index, + head_index, + block_index, + :block_count, + ].tolist(): + entries.add( + ( + batch_index, + head_index, + block_index, + int(other_block), + ) + ) + return entries + + +def _logical_tokens(pack: SharedPrefixPack) -> int: + return sum(int(positions.numel()) for positions in pack.positions_by_sequence) + + +def _csv_values(value: str) -> tuple[str, ...]: + values = tuple(part.strip() for part in value.split(",") if part.strip()) + if not values: + raise ValueError("CSV option must contain at least one value") + return values + + +def _write(path: Path, payload: dict[str, object]) -> None: + line = json.dumps(payload, sort_keys=True) + with path.open("a", encoding="utf-8") as output: + output.write(line + "\n") + print(line, flush=True) + + +if __name__ == "__main__": + typer.run(main) diff --git a/dev/trainer_rank_topology_check.py b/dev/trainer_rank_topology_check.py new file mode 100644 index 000000000..147a56cdd --- /dev/null +++ b/dev/trainer_rank_topology_check.py @@ -0,0 +1,1122 @@ +from __future__ import annotations + +from dataclasses import dataclass +import json +import os +import time + +import torch +import torch.distributed as dist +import typer + +from art.megatron.trainer_rank import ( + ForwardInput, + ForwardOutput, + TopK, + TrainerRank, + _empty_logits_like_positions, + _gather_target_logprobs, + _language_model, + _pack_forward_items, + _PackedForwardBatch, + _select_positions, +) + + +@dataclass +class 
CheckOutput: + source_positions: torch.Tensor + target_logprobs: torch.Tensor | None + top_k: TopK | None + logits: torch.Tensor | None + hidden_states: torch.Tensor | None + + +@dataclass(frozen=True) +class DiffStats: + max_abs_diff: float = 0.0 + mean_abs_pct: float = 0.0 + + def merge(self, other: DiffStats) -> DiffStats: + return DiffStats( + max_abs_diff=max(self.max_abs_diff, other.max_abs_diff), + mean_abs_pct=max(self.mean_abs_pct, other.mean_abs_pct), + ) + + +def main( + model: str = "Qwen/Qwen3-0.6B", + layers: int = 1, + head_chunk_a: int = 17, + head_chunk_b: int = 512, + max_prefix_depth: int = 1, + request_case: str = "shared", + stress_tokens: int = 0, + max_unpacked_output_gb: float = 0.25, + debug_output: str = "none", + compare_independent: bool = False, + compare_same_layout: bool = False, +) -> None: + os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1") + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + try: + from megatron.core import parallel_state as ps + + from art.megatron import train as megatron_train + + torch.manual_seed(1234) + provider_configure = ( + (lambda provider: setattr(provider, "num_layers", layers)) + if layers > 0 + else None + ) + runtime = megatron_train.build_training_runtime( + model_identifier=model, + provider_configure=provider_configure, + print_env=dist.get_rank() == 0, + ) + for chunk in runtime.model: + chunk.eval() + + requests = ( + _stress_requests(stress_tokens) + if stress_tokens > 0 + else _requests(request_case) + ) + requests = _debug_output_requests(requests, debug_output) + unpacked_output_gb = _estimate_unpacked_output_gb(requests, runtime) + if max_unpacked_output_gb > 0 and unpacked_output_gb > max_unpacked_output_gb: + if dist.get_rank() == 0: + print( + json.dumps( + { + "world": 
dist.get_world_size(), + "dp": int(ps.get_data_parallel_world_size()), + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + "stress_tokens": stress_tokens, + "estimated_unpacked_output_gb": round( + unpacked_output_gb, 3 + ), + "max_unpacked_output_gb": max_unpacked_output_gb, + "skipped": "unpacked_output_cap", + }, + sort_keys=True, + ), + flush=True, + ) + dist.barrier() + return + dp_rank = int(ps.get_data_parallel_rank()) + dp_size = int(ps.get_data_parallel_world_size()) + local_pairs = [ + (index, request) + for index, request in enumerate(requests) + if index % dp_size == dp_rank + ] + local_requests = [request for _, request in local_pairs] + + rank_a = TrainerRank( + runtime, + head_chunk_tokens=head_chunk_a, + shared_prefix_max_depth=max_prefix_depth, + ) + rank_b = TrainerRank( + runtime, + head_chunk_tokens=head_chunk_b, + shared_prefix_max_depth=max_prefix_depth, + ) + independent_outputs: list[CheckOutput] | None = None + same_layout_outputs: list[CheckOutput] | None = None + + torch.cuda.reset_peak_memory_stats() + diff_stats = DiffStats() + with torch.no_grad(): + started_at = time.perf_counter() + if request_case == "target_only": + _debug("forward-target-only") + outputs_a = list(rank_a.forward(local_requests)) + outputs_b = list(rank_b.forward(local_requests)) + oracle_outputs, actual_source_positions = _packed_oracle( + rank_a, local_requests + ) + elif stress_tokens > 0: + _debug("forward-a") + outputs_a = list(rank_a.forward(local_requests)) + outputs_b = outputs_a + actual_source_positions = _source_positions(rank_a, local_requests) + oracle_outputs = [ + _as_check_output(source_positions, output) + for source_positions, output in zip( + actual_source_positions, + outputs_a, + strict=True, + ) + ] + else: + _debug("forward-shared") + ( + outputs_a, + outputs_b, + oracle_outputs, + actual_source_positions, + ) = _shared_hidden_check(rank_a, rank_b, local_requests) + if 
compare_independent and request_case in {"shared", "unique", "deep"}: + independent_outputs = _independent_check_outputs( + rank_a, local_requests + ) + if int(ps.get_context_parallel_world_size()) <= 1: + for index, (actual, independent) in enumerate( + zip(outputs_a, independent_outputs, strict=True) + ): + diff_stats = diff_stats.merge( + _assert_close( + actual, + independent, + f"independent[{index}]", + ), + ) + if compare_same_layout and request_case in {"shared", "unique", "deep"}: + same_layout_outputs = _same_layout_check_outputs( + rank_a, + local_requests, + ) + for index, (actual, same_layout) in enumerate( + zip(outputs_a, same_layout_outputs, strict=True) + ): + diff_stats = diff_stats.merge( + _assert_close( + actual, + same_layout, + f"same_layout[{index}]", + ), + ) + _debug("compare") + elapsed_s = time.perf_counter() - started_at + + peak_memory_gb = torch.tensor( + torch.cuda.max_memory_allocated() / 1024**3, + device=rank_a.device, + ) + for index, (actual, chunked, oracle) in enumerate( + zip(outputs_a, outputs_b, oracle_outputs, strict=True) + ): + if int(oracle.source_positions.numel()) == 0: + continue + diff_stats = diff_stats.merge( + _assert_close(actual, chunked, f"chunk[{index}]"), + ) + diff_stats = diff_stats.merge( + _assert_close(actual, oracle, f"oracle[{index}]"), + ) + + diff_tensor = torch.tensor( + [diff_stats.max_abs_diff, diff_stats.mean_abs_pct], + device=rank_a.device, + ) + dist.all_reduce(diff_tensor, op=dist.ReduceOp.MAX) + dist.all_reduce(peak_memory_gb, op=dist.ReduceOp.MAX) + max_diff_value = float(diff_tensor[0].item()) + mean_abs_pct_value = float(diff_tensor[1].item()) + records = _records( + local_pairs=local_pairs, + actual_outputs=outputs_a, + actual_source_positions=actual_source_positions, + oracle_outputs=oracle_outputs, + independent_outputs=independent_outputs, + rank=int(dist.get_rank()), + dp=dp_rank, + tp=int(ps.get_tensor_model_parallel_rank()), + cp=int(ps.get_context_parallel_rank()), + ) + 
gathered: list[list[dict[str, object]] | None] = [None] * dist.get_world_size() + _debug("all-gather") + dist.all_gather_object(gathered, records) + _debug("reconstruct") + reconstruction_error: str | None = None + if dist.get_rank() == 0: + seen = { + record["input_index"] + for rank_records in gathered + for record in rank_records or [] + } + if seen != set(range(len(requests))): + reconstruction_error = f"DP reconstruction missed inputs: {seen}" + else: + try: + reconstructed_stats = _assert_reconstructed(gathered, requests) + max_diff_value = max( + max_diff_value, + reconstructed_stats.max_abs_diff, + ) + mean_abs_pct_value = max( + mean_abs_pct_value, + reconstructed_stats.mean_abs_pct, + ) + except AssertionError as exc: + reconstruction_error = str(exc) + if reconstruction_error is None: + print( + json.dumps( + { + "world": dist.get_world_size(), + "dp": dp_size, + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + "mean_abs_pct": mean_abs_pct_value, + "max_abs_diff": max_diff_value, + "records": sum( + len(rank_records or []) for rank_records in gathered + ), + "same_layout": compare_same_layout, + "stress_tokens": stress_tokens, + "estimated_unpacked_output_gb": round( + unpacked_output_gb, 3 + ), + "elapsed_s": round(elapsed_s, 3), + "peak_memory_gb": round(float(peak_memory_gb.item()), 3), + }, + sort_keys=True, + ), + flush=True, + ) + errors = [reconstruction_error] + dist.broadcast_object_list(errors, src=0) + if errors[0] is not None: + raise AssertionError(errors[0]) + dist.barrier() + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _requests( + request_case: str = "shared", +) -> list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + if request_case not in {"shared", "target_only", "unique", "deep"}: + raise ValueError( + "request_case must be 'shared', 'target_only', 'unique', or 'deep'" + ) + rows = [ + 
torch.tensor([11, 12, 13, 14, 15, 16, 17]), + torch.tensor([11, 12, 13, 14, 24, 25]), + torch.tensor([11, 12, 13, 14, 24, 26]), + torch.tensor([11, 12, 13, 27]), + torch.tensor([31, 32, 33, 34]), + torch.tensor([31, 32, 33, 35]), + torch.tensor([11, 12, 13, 14, 15, 16, 17]), + torch.tensor([41, 42, 43]), + torch.tensor([41, 42, 44, 45]), + torch.tensor([51, 52, 53, 54, 55]), + torch.tensor([61, 62, 63]), + torch.tensor([61, 62, 64, 65]), + torch.tensor([71, 72]), + torch.tensor([81, 82, 83, 84]), + torch.tensor([91, 92, 93]), + torch.tensor([101, 102, 103, 104, 105]), + ] + if request_case == "deep": + rows = _deep_rows() + if request_case == "unique": + rows = [row + 1000 * index for index, row in enumerate(rows)] + if request_case == "target_only": + target_only_labels = [_labels(row, 0) for row in rows] + target_only_labels[0][2] = -100 + target_only_labels[3][1] = -100 + target_only_labels[10][0] = -100 + return [ + ForwardInput(input_tokens=row, target_tokens=label) + for row, label in zip(rows, target_only_labels, strict=True) + ] + + labels = [_labels(row, offset) for offset, row in enumerate(rows)] + labels[0][2] = -100 + labels[3][1] = -100 + labels[10][0] = -100 + multi_labels = torch.stack((labels[1], (labels[1] + 17) % 1000), dim=1) + multi_labels[2, 1] = -100 + requests = [] + for mask, row in enumerate(rows): + target_tokens = None + if mask & 1: + target_tokens = multi_labels if mask == 1 else labels[mask] + requests.append( + ForwardInput( + input_tokens=row, + target_tokens=target_tokens, + top_k=3 if mask & 2 else None, + logits=bool(mask & 4), + hidden_states=bool(mask & 8), + ) + ) + return requests + + +def _debug_output_requests( + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + debug_output: str, +) -> list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + if debug_output == "none": + return requests + if debug_output 
== "hidden": + return [ + ForwardInput(input_tokens=request.input_tokens, hidden_states=True) + for request in requests + ] + if debug_output == "logits": + return [ + ForwardInput(input_tokens=request.input_tokens, logits=True) + for request in requests + ] + raise ValueError("debug_output must be 'none', 'hidden', or 'logits'") + + +def _deep_rows() -> list[torch.Tensor]: + return [ + torch.tensor([11, 12, 13, 14, 15, 16, 17]), + torch.tensor([11, 12, 13, 14, 15, 16, 18]), + torch.tensor([11, 12, 13, 14, 15, 19]), + torch.tensor([11, 12, 13, 14, 20]), + torch.tensor([11, 12, 21]), + torch.tensor([31, 32, 33, 34, 35]), + torch.tensor([31, 32, 33, 34, 36]), + torch.tensor([31, 32, 33, 37]), + torch.tensor([41, 42, 43]), + torch.tensor([41, 42, 44]), + torch.tensor([51, 52, 53, 54]), + torch.tensor([61, 62]), + torch.tensor([71, 72, 73, 74, 75]), + torch.tensor([71, 72, 73, 76]), + torch.tensor([81]), + torch.tensor([91, 92, 93]), + ] + + +def _stress_requests( + token_count: int, +) -> list[ForwardInput[None, None, None, torch.Tensor]]: + if token_count < 8: + raise ValueError("stress_tokens must be >= 8") + prefix_len = token_count // 2 + tail_len = max(1, token_count // 4) + prefix = _stress_tokens(0, prefix_len) + return [ + ForwardInput( + input_tokens=torch.cat((prefix, _stress_tokens(10_000, tail_len))), + hidden_states=True, + ), + ForwardInput( + input_tokens=torch.cat((prefix, _stress_tokens(20_000, tail_len))), + hidden_states=True, + ), + ForwardInput(input_tokens=_stress_tokens(30_000, tail_len), hidden_states=True), + ] + + +def _stress_tokens(offset: int, length: int) -> torch.Tensor: + return (torch.arange(length, dtype=torch.long) + offset) % 32_000 + 100 + + +def _estimate_unpacked_output_gb( + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + runtime: object, +) -> float: + provider = getattr(runtime, "provider") + model = _language_model(getattr(runtime, "model")[0]) + 
hidden_size = int( + getattr(provider, "hidden_size", None) + or getattr(getattr(model, "config", None), "hidden_size", 0) + ) + vocab_size = int( + getattr(getattr(model, "config", None), "padded_vocab_size", None) + or getattr(model, "vocab_size", 0) + ) + dtype_size = next(getattr(runtime, "model")[0].parameters()).element_size() + bytes_total = sum( + _request_output_bytes( + request, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + for request in requests + ) + return bytes_total / 1024**3 + + +def _request_output_bytes( + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + *, + hidden_size: int, + vocab_size: int, + dtype_size: int, +) -> int: + seq_len = int(request.input_tokens.numel()) + bytes_total = 0 + if request.target_tokens is not None: + bytes_total += int(request.target_tokens.numel()) * 4 + if request.top_k is not None: + bytes_total += seq_len * int(request.top_k) * (4 + 8) + if request.logits: + bytes_total += seq_len * vocab_size * dtype_size + if request.hidden_states: + bytes_total += seq_len * hidden_size * dtype_size + return bytes_total + + +def _debug(label: str) -> None: + if os.environ.get("TRAINER_RANK_CHECK_DEBUG") != "1": + return + print(f"[rank{dist.get_rank()}] {label}", flush=True) + + +def _labels(tokens: torch.Tensor, offset: int) -> torch.Tensor: + return ((tokens * 7 + 3 + offset) % 1000).to(dtype=torch.long) + + +def _packed_oracle( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> tuple[list[CheckOutput], tuple[torch.Tensor, ...]]: + items = [rank._forward_item(request) for request in requests] + prepared = rank._prepare_packed_forward( + _pack_forward_items(items, max_depth=rank.shared_prefix_max_depth) + ) + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + return ( + _packed_oracle_from_hidden(rank, items, 
prepared, hidden), + prepared.source_positions_by_item, + ) + + +def _shared_hidden_check( + rank_a: TrainerRank, + rank_b: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> tuple[ + list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + list[CheckOutput], + tuple[torch.Tensor, ...], +]: + items = [rank_a._forward_item(request) for request in requests] + prepared = rank_a._prepare_packed_forward( + _pack_forward_items(items, max_depth=rank_a.shared_prefix_max_depth) + ) + hidden = rank_a._gather_sequence_parallel_hidden(rank_a._decoder_hidden(prepared)) + outputs_a = _outputs_from_hidden(rank_a, items, prepared, hidden) + outputs_b = _outputs_from_hidden(rank_b, items, prepared, hidden) + oracle = _packed_oracle_from_hidden(rank_a, items, prepared, hidden) + return ( + outputs_a, + outputs_b, + oracle, + prepared.source_positions_by_item, + ) + + +def _independent_check_outputs( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> list[CheckOutput]: + outputs: list[CheckOutput] = [] + for request in requests: + source_positions = _source_positions(rank, [request])[0] + outputs.append(_as_check_output(source_positions, rank.forward([request])[0])) + return outputs + + +def _same_layout_check_outputs( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> list[CheckOutput]: + items = [rank._forward_item(request) for request in requests] + batch = _pack_forward_items(items, max_depth=rank.shared_prefix_max_depth) + outputs = [] + for index, positions in enumerate(batch.positions_by_item): + mutated = _mutated_batch(batch, 
keep_positions=positions) + prepared = rank._prepare_packed_forward(mutated) + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + mutated_outputs = _outputs_from_hidden(rank, items, prepared, hidden) + outputs.append( + _as_check_output( + prepared.source_positions_by_item[index], + mutated_outputs[index], + ) + ) + return outputs + + +def _mutated_batch( + batch: _PackedForwardBatch, + *, + keep_positions: torch.Tensor, +) -> _PackedForwardBatch: + tokens = batch.tokens.clone() + mutate = torch.ones(int(tokens.shape[1]), dtype=torch.bool, device=tokens.device) + mutate[keep_positions.to(device=tokens.device)] = False + replacement = ( + torch.arange(int(tokens.shape[1]), dtype=tokens.dtype, device=tokens.device) + + 50_000 + ) + tokens[0, mutate] = replacement[mutate] % 100_000 + return _PackedForwardBatch( + tokens=tokens, + group_ids=batch.group_ids, + parent_ids=batch.parent_ids, + position_ids=batch.position_ids, + positions_by_item=batch.positions_by_item, + ) + + +def _outputs_from_hidden( + rank: TrainerRank, + items: list[object], + prepared: object, + hidden: torch.Tensor, +) -> list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + head_outputs = rank._project_head(items, prepared, hidden) + outputs = [] + for index, (item, positions) in enumerate( + zip(items, prepared.positions_by_item, strict=True) + ): + hidden_states = ( + _select_positions(hidden, positions) if item.request.hidden_states else None + ) + outputs.append( + ForwardOutput( + target_logprobs=head_outputs.target_logprobs[index], + top_k=head_outputs.top_k[index], + logits=head_outputs.logits[index], + hidden_states=hidden_states, + ) + ) + return outputs + + +def _packed_oracle_from_hidden( + rank: TrainerRank, + items: list[object], + prepared: object, + hidden: torch.Tensor, +) -> list[CheckOutput]: + model = _language_model(rank.runtime.model[0]) + output_weight = ( + 
model.shared_embedding_or_output_weight() + if bool(model.share_embeddings_and_output_weights) + else None + ) + + outputs: list[CheckOutput] = [] + for item, positions, source_positions in zip( + items, + prepared.positions_by_item, + prepared.source_positions_by_item, + strict=True, + ): + needs_projection = ( + item.labels is not None or item.request.logits or item.request.top_k + ) + all_logits = None + if needs_projection: + all_logits = ( + rank._logits_from_hidden_rows( + model, + _select_positions(hidden, positions), + output_weight=output_weight, + ) + if int(positions.numel()) + else _empty_logits_like_positions(positions, model, hidden) + ) + logprobs = ( + None + if all_logits is None + else torch.log_softmax(all_logits.float(), dim=-1) + ) + + target_logprobs = None + if item.labels is not None: + if logprobs is None: + raise RuntimeError("target_logprobs oracle requires logprobs") + labels = item.labels.to(device=logprobs.device).index_select( + 0, source_positions.to(device=logprobs.device) + ) + target_logprobs = _gather_target_logprobs(logprobs, labels) + + top_k = None + if item.request.top_k is not None: + if all_logits is None: + raise RuntimeError("top_k oracle requires logits") + log_z = torch.logsumexp(all_logits.float(), dim=-1) + values, tokens = torch.topk( + all_logits.float(), k=item.request.top_k, dim=-1 + ) + top_k = TopK(logprobs=values - log_z.unsqueeze(1), tokens=tokens) + + hidden_states = None + if item.request.hidden_states: + hidden_states = _select_positions(hidden, positions) + + outputs.append( + CheckOutput( + source_positions=source_positions, + target_logprobs=target_logprobs, + top_k=top_k, + logits=all_logits if item.request.logits else None, + hidden_states=hidden_states, + ) + ) + return outputs + + +def _source_positions( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> tuple[torch.Tensor, ...]: + items = 
[rank._forward_item(request) for request in requests] + prepared = rank._prepare_packed_forward( + _pack_forward_items(items, max_depth=rank.shared_prefix_max_depth) + ) + return prepared.source_positions_by_item + + +def _as_check_output( + source_positions: torch.Tensor, + output: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], +) -> CheckOutput: + return CheckOutput( + source_positions=source_positions, + target_logprobs=output.target_logprobs, + top_k=output.top_k, + logits=output.logits, + hidden_states=output.hidden_states, + ) + + +def _records( + *, + local_pairs: list[ + tuple[ + int, + ForwardInput[ + torch.Tensor | None, + TopK | None, + torch.Tensor | None, + torch.Tensor | None, + ], + ] + ], + actual_outputs: list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + actual_source_positions: tuple[torch.Tensor, ...], + oracle_outputs: list[CheckOutput], + independent_outputs: list[CheckOutput] | None, + rank: int, + dp: int, + tp: int, + cp: int, +) -> list[dict[str, object]]: + records: list[dict[str, object]] = [] + independent_records: list[CheckOutput | None] = ( + independent_outputs + if independent_outputs is not None + else [None] * len(local_pairs) + ) + for local_index, ( + (input_index, _), + actual, + actual_sources, + oracle, + independent, + ) in enumerate( + zip( + local_pairs, + actual_outputs, + actual_source_positions, + oracle_outputs, + independent_records, + strict=True, + ) + ): + records.append( + { + "input_index": input_index, + "local_index": local_index, + "rank": rank, + "dp": dp, + "tp": tp, + "cp": cp, + "actual": _cpu_record(actual_sources, actual), + "oracle": _cpu_record(oracle.source_positions, oracle), + "independent": ( + None + if independent is None + else _cpu_record(independent.source_positions, independent) + ), + } + ) + return records + + +def _cpu_record( + source_positions: torch.Tensor, + output: 
ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + | CheckOutput, +) -> dict[str, torch.Tensor | None]: + return { + "source_positions": source_positions.cpu(), + "target_logprobs": _cpu(output.target_logprobs), + "logits": _cpu(output.logits), + "hidden_states": _cpu(output.hidden_states), + "top_k_logprobs": None if output.top_k is None else _cpu(output.top_k.logprobs), + "top_k_tokens": None if output.top_k is None else _cpu(output.top_k.tokens), + } + + +def _cpu(tensor: torch.Tensor | None) -> torch.Tensor | None: + return None if tensor is None else tensor.detach().cpu() + + +def _assert_reconstructed( + gathered: list[list[dict[str, object]] | None], + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> DiffStats: + diff_stats = DiffStats() + records = [ + record + for rank_records in gathered + for record in rank_records or [] + if record["tp"] == 0 + ] + for input_index, request in enumerate(requests): + _debug(f"reconstruct-input-{input_index}") + actual = [ + record["actual"] + for record in records + if record["input_index"] == input_index + ] + oracle = [ + record["oracle"] + for record in records + if record["input_index"] == input_index + ] + independent = [ + record["independent"] + for record in records + if record["input_index"] == input_index + and record.get("independent") is not None + ] + length = int(request.input_tokens.numel()) + for key in ("target_logprobs", "logits", "hidden_states", "top_k_logprobs"): + _debug(f"reconstruct-input-{input_index}-{key}") + _debug(f"reconstruct-input-{input_index}-{key}-assemble-actual") + actual_value = _assemble(actual, key, length) + _debug( + f"reconstruct-input-{input_index}-{key}-actual-" + f"{_tensor_summary(actual_value)}" + ) + _debug(f"reconstruct-input-{input_index}-{key}-assemble-oracle") + oracle_value = _assemble(oracle, key, length) + _debug( + 
f"reconstruct-input-{input_index}-{key}-oracle-" + f"{_tensor_summary(oracle_value)}" + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-oracle") + diff_stats = diff_stats.merge( + _tensor_diff_value( + actual_value, + oracle_value, + f"reconstructed[{input_index}].{key}", + ), + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-oracle-done") + if independent: + _debug(f"reconstruct-input-{input_index}-{key}-assemble-independent") + independent_value = _assemble(independent, key, length) + _debug( + f"reconstruct-input-{input_index}-{key}-independent-" + f"{_tensor_summary(independent_value)}" + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-independent") + diff_stats = diff_stats.merge( + _tensor_diff_value( + actual_value, + independent_value, + f"independent[{input_index}].{key}", + ), + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-independent-done") + _debug(f"reconstruct-input-{input_index}-{key}-done") + actual_tokens = _assemble(actual, "top_k_tokens", length) + oracle_tokens = _assemble(oracle, "top_k_tokens", length) + if actual_tokens is None or oracle_tokens is None: + if actual_tokens is not oracle_tokens: + raise AssertionError( + f"reconstructed[{input_index}].top_k None mismatch" + ) + elif not torch.equal(actual_tokens, oracle_tokens): + actual_logprobs = _assemble(actual, "top_k_logprobs", length) + oracle_logprobs = _assemble(oracle, "top_k_logprobs", length) + if ( + actual_logprobs is None + or oracle_logprobs is None + or _tensor_diff_value( + actual_logprobs, + oracle_logprobs, + f"reconstructed[{input_index}].top_k.logprobs", + ).max_abs_diff + > 5e-6 + ): + raise AssertionError( + f"reconstructed[{input_index}].top_k.tokens mismatch" + ) + if independent: + independent_tokens = _assemble(independent, "top_k_tokens", length) + if actual_tokens is None or independent_tokens is None: + if actual_tokens is not independent_tokens: + raise AssertionError( + f"independent[{input_index}].top_k None mismatch" + ) + 
elif not torch.equal(actual_tokens, independent_tokens): + actual_logprobs = _assemble(actual, "top_k_logprobs", length) + independent_logprobs = _assemble( + independent, + "top_k_logprobs", + length, + ) + if ( + actual_logprobs is None + or independent_logprobs is None + or _tensor_diff_value( + actual_logprobs, + independent_logprobs, + f"independent[{input_index}].top_k.logprobs", + ).max_abs_diff + > 5e-6 + ): + raise AssertionError( + f"independent[{input_index}].top_k.tokens mismatch" + ) + return diff_stats + + +def _assemble( + records: list[object], + key: str, + length: int, +) -> torch.Tensor | None: + typed_records = [record for record in records if isinstance(record, dict)] + values = [record[key] for record in typed_records if record[key] is not None] + if not values: + return None + first = values[0] + if not isinstance(first, torch.Tensor): + raise TypeError(key) + output = torch.empty((length, *first.shape[1:]), dtype=first.dtype) + filled = torch.zeros(length, dtype=torch.bool) + for record in typed_records: + value = record[key] + if value is None: + continue + if not isinstance(value, torch.Tensor): + raise TypeError(key) + positions = record["source_positions"] + if not isinstance(positions, torch.Tensor): + raise TypeError("source_positions") + output[positions] = value + filled[positions] = True + if not bool(filled.all().item()): + raise AssertionError(f"{key} reconstruction missed positions") + return output + + +def _tensor_summary(tensor: torch.Tensor | None) -> str: + if tensor is None: + return "None" + return f"shape={tuple(tensor.shape)} device={tensor.device} dtype={tensor.dtype}" + + +def _assert_close( + actual: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + expected: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + | CheckOutput, + label: str, +) -> DiffStats: + diffs = [ + _tensor_diff( + actual.target_logprobs, 
expected.target_logprobs, f"{label}.target_logprobs" + ) + ] + diffs.append(_tensor_diff(actual.logits, expected.logits, f"{label}.logits")) + diffs.append( + _tensor_diff( + actual.hidden_states, expected.hidden_states, f"{label}.hidden_states" + ) + ) + if actual.top_k is None or expected.top_k is None: + if actual.top_k is not expected.top_k: + raise AssertionError(f"{label}.top_k None mismatch") + else: + try: + top_k_diff = _tensor_diff( + actual.top_k.logprobs, + expected.top_k.logprobs, + f"{label}.top_k.logprobs", + ) + except AssertionError as exc: + flat_offset = int( + (actual.top_k.logprobs.float() - expected.top_k.logprobs.float()) + .abs() + .flatten() + .argmax() + ) + row, _ = divmod(flat_offset, int(actual.top_k.logprobs.shape[1])) + raise AssertionError( + f"{exc}; actual_row={actual.top_k.logprobs[row].tolist()} " + f"expected_row={expected.top_k.logprobs[row].tolist()} " + f"actual_tokens={actual.top_k.tokens[row].tolist()} " + f"expected_tokens={expected.top_k.tokens[row].tolist()}" + ) from exc + diffs.append(top_k_diff) + if ( + not torch.equal(actual.top_k.tokens, expected.top_k.tokens) + and top_k_diff.max_abs_diff > 5e-6 + ): + mismatch = torch.nonzero( + actual.top_k.tokens != expected.top_k.tokens, + as_tuple=False, + )[0] + row = int(mismatch[0].item()) + col = int(mismatch[1].item()) + raise AssertionError( + f"{label}.top_k.tokens mismatch at ({row}, {col}): " + f"actual={int(actual.top_k.tokens[row, col].item())} " + f"expected={int(expected.top_k.tokens[row, col].item())} " + f"actual_logprob={float(actual.top_k.logprobs[row, col].item())} " + f"expected_logprob={float(expected.top_k.logprobs[row, col].item())}" + ) + return _merge_diff_stats(diffs) + + +def _tensor_diff( + actual: torch.Tensor | None, + expected: torch.Tensor | None, + label: str, +) -> DiffStats: + return _tensor_diff_value(actual, expected, label) + + +def _tensor_diff_value( + actual: torch.Tensor | None, + expected: torch.Tensor | None, + label: str, +) -> 
DiffStats: + if actual is None or expected is None: + if actual is not expected: + raise AssertionError(f"{label} None mismatch") + return DiffStats() + if actual.shape != expected.shape: + raise AssertionError( + f"{label} shape mismatch: {actual.shape} != {expected.shape}" + ) + actual_for_diff = actual + expected_for_diff = expected + if torch.cuda.is_available(): + actual_for_diff = actual_for_diff.to(device="cuda") + expected_for_diff = expected_for_diff.to(device="cuda") + if actual_for_diff.numel(): + abs_diff = (actual_for_diff.float() - expected_for_diff.float()).abs() + max_abs_diff = float(abs_diff.max().item()) + denominator = float(expected_for_diff.float().abs().mean().item()) + mean_abs_pct = float(abs_diff.mean().item()) / (denominator + 1e-18) + else: + max_abs_diff = 0.0 + mean_abs_pct = 0.0 + tolerance = 5e-6 if "logprobs" in label else 0.0 + _debug( + f"{label} max_abs_diff={max_abs_diff} " + f"mean_abs_pct={mean_abs_pct} tolerance={tolerance}" + ) + if max_abs_diff > tolerance: + raise AssertionError(f"{label} max diff {max_abs_diff}") + return DiffStats(max_abs_diff=max_abs_diff, mean_abs_pct=mean_abs_pct) + + +def _merge_diff_stats(stats: list[DiffStats]) -> DiffStats: + merged = DiffStats() + for stat in stats: + merged = merged.merge(stat) + return merged + + +if __name__ == "__main__": + typer.run(main) diff --git a/scripts/build-gpu-image.sh b/scripts/build-gpu-image.sh index 299678584..dbce31484 100755 --- a/scripts/build-gpu-image.sh +++ b/scripts/build-gpu-image.sh @@ -10,10 +10,12 @@ Options: --image-repo REPO Image repository to publish --infra INFRA Kubernetes-backed SkyPilot infra (default: k8s/cks-wb3) --no-cache Disable registry-backed BuildKit cache + --no-prewarm-modal Skip prebuilding the pushed image in Modal --no-prewarm-nodes Skip pre-pulling the pushed image on GPU nodes --pull-image-repo REPO Image repository for cluster pulls/prewarm + --prewarm-modal Require prebuilding the pushed image in Modal --prewarm-timeout DUR 
Timeout for the prewarm DaemonSet rollout (default: 30m) - --tag TAG Image tag to publish + --tag TAG Image tag to publish (default: latest) --help Show this help EOF } @@ -24,12 +26,13 @@ cluster_name="" infra="${SKY_INFRA:-k8s/cks-wb3}" image_repo="${ART_IMAGE_REPO:-}" pull_image_repo="${ART_PULL_IMAGE_REPO:-}" -image_tag="" +image_tag="${IMAGE_TAG:-latest}" docker_config_path="${DOCKER_CONFIG_PATH:-${HOME}/.docker/config.json}" buildkit_image="${BUILDKIT_IMAGE:-moby/buildkit:v0.29.0-rootless}" buildkit_namespace="${KUBECTL_NAMESPACE:-default}" buildkit_wait_timeout="${BUILDKIT_WAIT_TIMEOUT:-300s}" no_cache="${NO_CACHE:-false}" +prewarm_modal="${PREWARM_MODAL:-auto}" prewarm_nodes="${PREWARM_NODES:-true}" prewarm_namespace="${PREWARM_NAMESPACE:-default}" prewarm_name="${PREWARM_NAME:-art-gpu-image-prewarm}" @@ -58,6 +61,10 @@ while [[ $# -gt 0 ]]; do no_cache=true shift ;; + --no-prewarm-modal) + prewarm_modal=false + shift + ;; --no-prewarm-nodes) prewarm_nodes=false shift @@ -66,6 +73,10 @@ while [[ $# -gt 0 ]]; do pull_image_repo="$2" shift 2 ;; + --prewarm-modal) + prewarm_modal=true + shift + ;; --prewarm-timeout) prewarm_timeout="$2" shift 2 @@ -86,6 +97,14 @@ while [[ $# -gt 0 ]]; do esac done +case "${prewarm_modal}" in + auto|true|false) ;; + *) + echo "PREWARM_MODAL must be one of: auto, true, false" >&2 + exit 1 + ;; +esac + case "${infra}" in k8s/*) kube_context="${infra#k8s/}" @@ -111,10 +130,6 @@ art_sha="$(git -C "${repo_root}" rev-parse HEAD)" art_short_sha="$(git -C "${repo_root}" rev-parse --short=12 HEAD)" timestamp="$(date +%m%d-%H%M%S)" -if [[ -z "${image_tag}" ]]; then - image_tag="skypilot-${art_short_sha}" -fi - if [[ -z "${cluster_name}" ]]; then cluster_name="art-gpu-build-${timestamp}" fi @@ -409,6 +424,38 @@ if [[ -n "${prewarm_refresh_tag_image}" ]]; then prewarm_display="${prewarm_image} and refreshing ${prewarm_refresh_tag_image}" fi +modal_auth_available=false +if [[ "${prewarm_modal}" != "false" ]]; then + if uv run --with 
'modal>=1.5.0' python - <<'PY' >/dev/null 2>&1; then +import modal + +modal.Workspace.from_context().hydrate() +PY + modal_auth_available=true + fi +fi + +if [[ "${prewarm_modal}" == "true" || "${modal_auth_available}" == "true" ]]; then + echo "Prewarming ${image_repo}:${image_tag} in Modal image cache" + MODAL_FORCE_BUILD=1 uv run --with 'modal>=1.5.0' python - "${image_repo}:${image_tag}" <<'PY' +import sys + +import modal + +image = ( + modal.Image.from_registry(sys.argv[1], add_python="3.12") + .apt_install("openssh-server", "sudo", "rsync", "curl", "procps", "patch", "lsof") +) +app = modal.App.lookup("skypilot-modal", create_if_missing=True) +with modal.enable_output(): + image.build(app) +PY +elif [[ "${prewarm_modal}" == "auto" ]]; then + echo "Skipping Modal image prewarm: Modal auth unavailable" +else + echo "Skipping Modal image prewarm" +fi + dump_prewarm_diagnostics() { echo "::group::Prewarm diagnostics" "${kubectl_cmd[@]}" get daemonset -n "${prewarm_namespace}" "${prewarm_name}" -o wide || true diff --git a/src/art/megatron/__init__.py b/src/art/megatron/__init__.py index 3c2e5e5b9..a87296507 100644 --- a/src/art/megatron/__init__.py +++ b/src/art/megatron/__init__.py @@ -1,6 +1,15 @@ from typing import Any -__all__ = ["MegatronBackend"] +_TRAINER_RANK_EXPORTS = ( + "AdamParams", + "ForwardInput", + "ForwardOutput", + "MicroBatch", + "TopK", + "TrainerRank", +) + +__all__ = ["MegatronBackend", *_TRAINER_RANK_EXPORTS] def __getattr__(name: str) -> Any: @@ -8,4 +17,8 @@ def __getattr__(name: str) -> Any: from .backend import MegatronBackend return MegatronBackend + if name in _TRAINER_RANK_EXPORTS: + from . 
@dataclass(frozen=True, slots=True)
class PreparedBlockMaskContext:
    """Group metadata computed once per packed row and shared by block-mask builds.

    Built by ``prepare_block_mask_context`` so that repeated sparse block-mask
    construction does not re-parse the same ``group_ids`` / ``parent_ids``.
    """

    # Flattened CPU int64 copy of the per-token group ids.
    group_ids: torch.Tensor
    # Flattened CPU int64 copy of the per-token parent ids (same length as group_ids).
    parent_ids: torch.Tensor
    # NumPy view of ``group_ids`` (created with ``Tensor.numpy()``).
    group_ids_np: np.ndarray
    # int64 group ids that label the axes of ``group_can_attend``.
    # NOTE(review): consumed via ``np.searchsorted``, so this is presumably
    # sorted ascending — confirm against ``group_can_attend_matrix()``.
    sorted_group_ids: np.ndarray
    # Boolean matrix indexed as [query group index, key group index].
    # NOTE(review): index 0 appears to be reserved for "invalid / no group" and
    # index i + 1 for ``sorted_group_ids[i]`` — confirm the matrix has that
    # extra leading row/column.
    group_can_attend: np.ndarray
    # Maximum depth reported by the parsed shared-prefix tree.
    max_depth: int
device=device, + dtype=torch.bool, + ) def mask_mod( batch_idx: torch.Tensor, @@ -37,9 +52,11 @@ def mask_mod( del batch_idx, head_idx q_abs_local = q_abs_tensor[query_idx] k_abs_local = k_abs_tensor[kv_idx] - same_group = q_group_tensor[query_idx] == k_group_tensor[kv_idx] - parent_prefix = q_parent_tensor[query_idx] == k_group_tensor[kv_idx] - return (q_abs_local >= k_abs_local) & (same_group | parent_prefix) + allowed_group = group_can_attend_tensor[ + q_group_tensor[query_idx], + k_group_tensor[kv_idx], + ] + return (q_abs_local >= k_abs_local) & allowed_group return mask_mod @@ -49,10 +66,18 @@ def _dense_blocks_to_ordered( *, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: - counts = torch.from_numpy(blocks.sum(axis=-1).astype(np.int32)) - indices = torch.from_numpy( - np.argsort(-blocks.astype(np.int32), axis=-1, kind="stable").astype(np.int32) - ) + row_indices, column_indices = np.nonzero(blocks) + counts_np = np.bincount(row_indices, minlength=blocks.shape[0]).astype(np.int32) + indices_np = np.zeros(blocks.shape, dtype=np.int32) + if int(row_indices.size) > 0: + starts = np.concatenate(([0], np.cumsum(counts_np[:-1], dtype=np.int64))) + active_rows = np.flatnonzero(counts_np) + for row_index in active_rows: + start = int(starts[row_index]) + end = start + int(counts_np[row_index]) + indices_np[row_index, : end - start] = column_indices[start:end] + counts = torch.from_numpy(counts_np) + indices = torch.from_numpy(indices_np) return ( counts.view(1, 1, -1).to(device=device), indices.view(1, 1, blocks.shape[0], blocks.shape[1]).to(device=device), @@ -72,72 +97,129 @@ def _select_with_invalid_np( return selected -def _build_q_block_group_state( +def _is_strictly_increasing(values: np.ndarray) -> bool: + return int(values.size) <= 1 or bool(np.all(values[1:] > values[:-1])) + + +def _block_min_max( + values: np.ndarray, + starts: np.ndarray, + ends: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + mins = np.empty(starts.shape, 
dtype=values.dtype) + maxes = np.empty(starts.shape, dtype=values.dtype) + for index, (start, end) in enumerate(zip(starts, ends, strict=True)): + block = values[int(start) : int(end)] + mins[index] = block.min() + maxes[index] = block.max() + return mins, maxes + + +def _remap_group_values( + values: np.ndarray, *, - q_abs: np.ndarray, - q_group: np.ndarray, - q_parent: np.ndarray, - q_block: int, - block_idx: int, -) -> tuple[int, dict[int, int], frozenset[int]]: - start = int(block_idx) * q_block - end = min((int(block_idx) + 1) * q_block, int(q_abs.size)) - q = q_abs[start:end] - q_group_block = q_group[start:end] - q_parent_block = q_parent[start:end] - q_min = int(q.min()) if int(q.size) else 0 - max_by_group: dict[int, int] = {} - all_groups: list[int] = [] - for group_value in np.unique(np.concatenate((q_group_block, q_parent_block))): - allowed = (q_group_block == group_value) | (q_parent_block == group_value) - if bool(allowed.any()): - max_by_group[int(group_value)] = int(q[allowed].max()) - if bool(allowed.all()): - all_groups.append(int(group_value)) - return q_min, max_by_group, frozenset(all_groups) - - -def _build_k_block_group_state( + sorted_group_ids: np.ndarray, +) -> np.ndarray: + remapped = np.full(values.shape, _INVALID_GROUP_INDEX, dtype=np.int32) + if int(sorted_group_ids.size) == 0: + return remapped + positions = np.searchsorted(sorted_group_ids, values) + in_bounds = positions < int(sorted_group_ids.size) + matched = np.zeros(values.shape, dtype=bool) + matched[in_bounds] = sorted_group_ids[positions[in_bounds]] == values[in_bounds] + remapped[matched] = positions[matched].astype(np.int32, copy=False) + 1 + return remapped + + +def _refine_exact_blocks( *, + partial_blocks: np.ndarray, + full_blocks: np.ndarray, + q_abs: np.ndarray, k_abs: np.ndarray, - k_group: np.ndarray, + q_group_index: np.ndarray, + k_group_index: np.ndarray, + group_can_attend: np.ndarray, + q_block: int, k_block: int, - block_idx: int, -) -> tuple[int, dict[int, 
int], tuple[int, ...]]: - start = int(block_idx) * k_block - end = min((int(block_idx) + 1) * k_block, int(k_abs.size)) - k = k_abs[start:end] - k_group_block = k_group[start:end] - k_max = int(k.max()) if int(k.size) else 0 - min_by_group: dict[int, int] = {} - for group_value in np.unique(k_group_block): - min_by_group[int(group_value)] = int(k[k_group_block == group_value].min()) - return k_max, min_by_group, tuple(min_by_group) - - -def _exact_block_state( - *, - q_state: tuple[int, dict[int, int], frozenset[int]], - k_state: tuple[int, dict[int, int], tuple[int, ...]], -) -> tuple[bool, bool]: - q_min, q_allowed_max, q_all_allowed = q_state - k_max, k_min, k_groups = k_state - if not any( - q_allowed_max.get(k_group_value, _INVALID_Q_GROUP) >= min_k - for k_group_value, min_k in k_min.items() - ): - return False, False - if int(q_min) < int(k_max): - return True, False - return True, all(k_group_value in q_all_allowed for k_group_value in k_groups) + q_len: int, + k_len: int, + skip_uniform_allowed: bool, +) -> None: + candidate_blocks = partial_blocks | full_blocks + if skip_uniform_allowed: + q_starts = np.arange(candidate_blocks.shape[0], dtype=np.int64) * int(q_block) + k_starts = np.arange(candidate_blocks.shape[1], dtype=np.int64) * int(k_block) + q_ends = np.minimum(q_starts + int(q_block), int(q_len)) + k_ends = np.minimum(k_starts + int(k_block), int(k_len)) + q_group_min, q_group_max = _block_min_max(q_group_index, q_starts, q_ends) + k_group_min, k_group_max = _block_min_max(k_group_index, k_starts, k_ends) + q_block_indices, k_block_indices = np.nonzero(candidate_blocks) + homogeneous = (q_group_min[q_block_indices] == q_group_max[q_block_indices]) & ( + k_group_min[k_block_indices] == k_group_max[k_block_indices] + ) + if bool(np.any(homogeneous)): + homogeneous_q = q_block_indices[homogeneous] + homogeneous_k = k_block_indices[homogeneous] + allowed = group_can_attend[ + q_group_min[homogeneous_q], + k_group_min[homogeneous_k], + ] + disallowed_q 
= homogeneous_q[~allowed] + disallowed_k = homogeneous_k[~allowed] + partial_blocks[disallowed_q, disallowed_k] = False + full_blocks[disallowed_q, disallowed_k] = False + mixed_q = q_block_indices[~homogeneous] + mixed_k = k_block_indices[~homogeneous] + partial_blocks[mixed_q, mixed_k] = True + full_blocks[mixed_q, mixed_k] = False + return + + for q_block_index, k_block_index in np.argwhere(candidate_blocks): + q_start = int(q_block_index) * q_block + k_start = int(k_block_index) * k_block + q_end = q_start + q_block + k_end = k_start + k_block + if q_end > q_len or k_end > k_len: + continue + + q_slice = slice(q_start, q_end) + k_slice = slice(k_start, k_end) + if skip_uniform_allowed: + q_groups = np.unique(q_group_index[q_slice]) + k_groups = np.unique(k_group_index[k_slice]) + group_allowed = group_can_attend[np.ix_(q_groups, k_groups)] + if bool(np.all(group_allowed)): + continue + if not bool(np.any(group_allowed)): + partial_blocks[q_block_index, k_block_index] = False + full_blocks[q_block_index, k_block_index] = False + continue + partial_blocks[q_block_index, k_block_index] = True + full_blocks[q_block_index, k_block_index] = False + continue + can_attend = group_can_attend[ + q_group_index[q_slice, None], + k_group_index[None, k_slice], + ] + causal = q_abs[q_slice, None] >= k_abs[None, k_slice] + allowed = causal & can_attend + if not bool(np.any(allowed)): + partial_blocks[q_block_index, k_block_index] = False + full_blocks[q_block_index, k_block_index] = False + elif bool(np.all(allowed)): + partial_blocks[q_block_index, k_block_index] = False + full_blocks[q_block_index, k_block_index] = True + else: + partial_blocks[q_block_index, k_block_index] = True + full_blocks[q_block_index, k_block_index] = False def _build_sparse_block_mask( spec: FlexMaskSpec, *, device: torch.device, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, + context: PreparedBlockMaskContext, block_size: tuple[int, int], ) -> BlockMask: q_block, k_block = block_size @@ 
-145,7 +227,6 @@ def _build_sparse_block_mask( k_blocks = (int(spec.k_len) + k_block - 1) // k_block partial_blocks = np.zeros((q_blocks, k_blocks), dtype=bool) full_blocks = np.zeros((q_blocks, k_blocks), dtype=bool) - touch_counts = np.zeros((q_blocks, k_blocks), dtype=np.int16) q_abs_tensor = spec.exact_mask.q_token_indices.detach().to( device="cpu", dtype=torch.int64, @@ -156,33 +237,32 @@ def _build_sparse_block_mask( ) q_abs = q_abs_tensor.numpy() k_abs = k_abs_tensor.numpy() - flat_group_ids = group_ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1) - flat_parent_ids = ( - parent_ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1) - ) - flat_group_ids_np = flat_group_ids.numpy() - flat_parent_ids_np = flat_parent_ids.numpy() + q_abs_sorted = _is_strictly_increasing(q_abs[q_abs >= 0]) + k_abs_sorted = _is_strictly_increasing(k_abs[k_abs >= 0]) q_group = _select_with_invalid_np( - flat_group_ids_np, - q_abs, - invalid_value=_INVALID_Q_GROUP, - ) - q_parent = _select_with_invalid_np( - flat_parent_ids_np, + context.group_ids_np, q_abs, - invalid_value=_INVALID_Q_PARENT, + invalid_value=-1, ) k_group = _select_with_invalid_np( - flat_group_ids_np, + context.group_ids_np, k_abs, - invalid_value=_INVALID_K_GROUP, + invalid_value=-1, + ) + q_group_index = _remap_group_values( + q_group, + sorted_group_ids=context.sorted_group_ids, + ) + k_group_index = _remap_group_values( + k_group, + sorted_group_ids=context.sorted_group_ids, ) mask_mod = _build_exact_mask_mod( q_abs=q_abs, k_abs=k_abs, - q_group=q_group, - q_parent=q_parent, - k_group=k_group, + q_group_index=q_group_index, + k_group_index=k_group_index, + group_can_attend=context.group_can_attend, device=device, ) if not spec.slices: @@ -208,15 +288,11 @@ def _build_sparse_block_mask( if int(q_block_indices.size) == 0 or int(k_block_indices.size) == 0: continue q_block_start = q_block_indices * q_block - q_block_end = np.minimum( - (q_block_indices + 1) * q_block, - int(spec.q_len), - ) + 
q_block_end_raw = (q_block_indices + 1) * q_block + q_block_end = np.minimum(q_block_end_raw, int(spec.q_len)) k_block_start = k_block_indices * k_block - k_block_end = np.minimum( - (k_block_indices + 1) * k_block, - int(spec.k_len), - ) + k_block_end_raw = (k_block_indices + 1) * k_block + k_block_end = np.minimum(k_block_end_raw, int(spec.k_len)) q_overlap_start = np.maximum( q_block_start, q_start, @@ -233,12 +309,22 @@ def _build_sparse_block_mask( k_block_end, k_end, ) - q_min = q_abs[q_overlap_start] - q_max = q_abs[q_overlap_end - 1] - k_min = k_abs[k_overlap_start] - k_max = k_abs[k_overlap_end - 1] - q_is_full = (q_overlap_start == q_block_start) & (q_overlap_end == q_block_end) - k_is_full = (k_overlap_start == k_block_start) & (k_overlap_end == k_block_end) + q_min, q_max = ( + (q_abs[q_overlap_start], q_abs[q_overlap_end - 1]) + if q_abs_sorted + else _block_min_max(q_abs, q_overlap_start, q_overlap_end) + ) + k_min, k_max = ( + (k_abs[k_overlap_start], k_abs[k_overlap_end - 1]) + if k_abs_sorted + else _block_min_max(k_abs, k_overlap_start, k_overlap_end) + ) + q_is_full = (q_overlap_start == q_block_start) & ( + q_overlap_end == q_block_end_raw + ) + k_is_full = (k_overlap_start == k_block_start) & ( + k_overlap_end == k_block_end_raw + ) covers_block = q_is_full[:, None] & k_is_full[None, :] if slice_.mask_kind == AttnMaskKind.FULL: has_any = np.ones( @@ -251,45 +337,25 @@ def _build_sparse_block_mask( q_slice = slice(int(q_block_indices[0]), int(q_block_indices[-1]) + 1) k_slice = slice(int(k_block_indices[0]), int(k_block_indices[-1]) + 1) - touch_counts[q_slice, k_slice] += has_any.astype(np.int16) partial_blocks[q_slice, k_slice] |= has_any full_blocks[q_slice, k_slice] |= is_full - ambiguous = (touch_counts > 1) & partial_blocks & ~full_blocks - q_state_cache: dict[int, tuple[int, dict[int, int], frozenset[int]]] = {} - k_state_cache: dict[int, tuple[int, dict[int, int], tuple[int, ...]]] = {} - for q_idx, k_idx in np.argwhere(ambiguous): - 
def prepare_block_mask_context(
    *,
    group_ids: torch.Tensor,
    parent_ids: torch.Tensor,
) -> PreparedBlockMaskContext:
    """Parse one packed row's group metadata into a reusable mask context.

    Args:
        group_ids: Rank-1 tensor of per-token group ids.
        parent_ids: Rank-1 tensor of per-token parent ids, same length as
            ``group_ids``.

    Returns:
        A ``PreparedBlockMaskContext`` holding CPU int64 copies of both
        inputs plus the group-attendance matrix and maximum depth reported by
        ``parse_shared_prefix_row``.

    Raises:
        RuntimeError: If either tensor is not rank 1, or their lengths differ.
    """
    both_rank_one = group_ids.ndim == 1 and parent_ids.ndim == 1
    if not both_rank_one:
        raise RuntimeError(
            "Shared-prefix sparse block masks require rank-1 group_ids and parent_ids."
        )
    if int(group_ids.numel()) != int(parent_ids.numel()):
        raise RuntimeError(
            "Shared-prefix sparse block masks require equal group_ids and parent_ids lengths."
        )

    def _as_flat_cpu_int64(ids: torch.Tensor) -> torch.Tensor:
        # Detach first so the context never keeps an autograd graph alive.
        return ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1)

    groups_cpu = _as_flat_cpu_int64(group_ids)
    parents_cpu = _as_flat_cpu_int64(parent_ids)
    tree = parse_shared_prefix_row(group_ids=groups_cpu, parent_ids=parents_cpu)
    matrix_group_ids, can_attend = tree.group_can_attend_matrix()
    return PreparedBlockMaskContext(
        group_ids=groups_cpu,
        parent_ids=parents_cpu,
        group_ids_np=groups_cpu.numpy(),
        sorted_group_ids=np.asarray(matrix_group_ids, dtype=np.int64),
        group_can_attend=np.asarray(can_attend, dtype=bool),
        max_depth=int(tree.max_depth),
    )
@@ -410,6 +508,23 @@ def build_block_mask( group_ids: torch.Tensor, parent_ids: torch.Tensor, device: torch.device, +) -> BlockMask | None: + return build_block_mask_from_context( + spec, + context=prepare_block_mask_context( + group_ids=group_ids, + parent_ids=parent_ids, + ), + device=device, + ) + + +def build_block_mask_from_context( + spec: FlexMaskSpec, + *, + context: PreparedBlockMaskContext, + device: torch.device, + validate: bool = True, ) -> BlockMask | None: if spec.q_len <= 0 or spec.k_len <= 0: return None @@ -423,12 +538,16 @@ def build_block_mask( "Exact stage k-token metadata length mismatch: " f"{int(spec.exact_mask.k_token_indices.numel())} != {int(spec.k_len)}" ) - _validate_supported_mask_spec(spec, group_ids=group_ids, parent_ids=parent_ids) + if validate: + _validate_supported_mask_spec( + spec, + group_ids=context.group_ids, + parent_ids=context.parent_ids, + ) block_size = normalize_sparse_block_size(spec.block_size) return _build_sparse_block_mask( spec, device=device, - group_ids=group_ids, - parent_ids=parent_ids, + context=context, block_size=block_size, ) diff --git a/src/art/megatron/context_parallel/builder.py b/src/art/megatron/context_parallel/builder.py index 77ac1b623..6b324d3f5 100644 --- a/src/art/megatron/context_parallel/builder.py +++ b/src/art/megatron/context_parallel/builder.py @@ -2,6 +2,8 @@ import torch +from art.megatron.shared_prefix_tree import parse_shared_prefix_tree + from .types import ( AttnMaskKind, AttnSlice, @@ -12,100 +14,6 @@ ) -def _valid_length( - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - *, - ignore_padding_group_id: int, -) -> int: - valid_mask = group_ids != ignore_padding_group_id - valid_count = int(valid_mask.sum().item()) - if valid_count == 0: - return 0 - if not bool(valid_mask[:valid_count].all().item()): - raise RuntimeError("Padding tokens must be a contiguous tail") - return _infer_terminal_padding_length( - group_ids[:valid_count], - parent_ids[:valid_count], - ) - - -def 
_infer_terminal_padding_length( - group_row: torch.Tensor, - parent_row: torch.Tensor, -) -> int: - if group_row.numel() == 0: - return 0 - runs = _scan_runs(group_row, parent_row) - if len(runs) < 2: - return int(group_row.numel()) - last_start, _last_end, last_group_id, last_parent_id = runs[-1] - if last_parent_id >= 0: - return int(group_row.numel()) - terminal_pair = (last_group_id, last_parent_id) - if any( - (group_id, parent_id) == terminal_pair - for _start, _end, group_id, parent_id in runs[:-1] - ): - return last_start - return int(group_row.numel()) - - -def _scan_runs( - group_row: torch.Tensor, - parent_row: torch.Tensor, -) -> list[tuple[int, int, int, int]]: - length = int(group_row.numel()) - if length == 0: - return [] - - group_changes = group_row[1:] != group_row[:-1] - parent_changes = parent_row[1:] != parent_row[:-1] - inconsistent_parent = torch.nonzero( - torch.logical_not(group_changes) & parent_changes, - as_tuple=False, - ).flatten() - if int(inconsistent_parent.numel()) > 0: - mismatch_index = int(inconsistent_parent[0].item()) + 1 - prior_boundaries = torch.nonzero( - group_changes[: mismatch_index - 1], - as_tuple=False, - ).flatten() - start = ( - 0 - if int(prior_boundaries.numel()) == 0 - else int(prior_boundaries[-1].item()) + 1 - ) - group_id = int(group_row[start].item()) - raise RuntimeError( - "Found one group run with inconsistent parent ids: " - f"group_id={group_id}, start={start}, end={mismatch_index}" - ) - - run_starts = torch.cat( - ( - torch.zeros(1, dtype=torch.int64, device=group_row.device), - torch.nonzero(group_changes, as_tuple=False).flatten() + 1, - ) - ) - run_ends = torch.cat( - ( - run_starts[1:], - torch.tensor([length], dtype=torch.int64, device=group_row.device), - ) - ) - starts = run_starts.to(device="cpu").tolist() - ends = run_ends.to(device="cpu").tolist() - group_ids = group_row.index_select(0, run_starts).to(device="cpu").tolist() - parent_ids = parent_row.index_select(0, 
run_starts).to(device="cpu").tolist() - return [ - (int(start), int(end), int(group_id), int(parent_id)) - for start, end, group_id, parent_id in zip( - starts, ends, group_ids, parent_ids, strict=True - ) - ] - - def _sort_and_dedupe_slices(slices: list[AttnSlice]) -> tuple[AttnSlice, ...]: sorted_slices = sorted( slices, @@ -138,18 +46,6 @@ def _sort_and_dedupe_slices(slices: list[AttnSlice]) -> tuple[AttnSlice, ...]: return tuple(deduped) -def _is_prompt_run( - *, - start: int, - group_id: int, - parent_id: int, - ignore_padding_group_id: int, -) -> bool: - return group_id == parent_id or ( - start == 0 and parent_id == ignore_padding_group_id - ) - - def build_shared_prefix_attention_spec( *, group_ids: torch.Tensor, @@ -166,127 +62,50 @@ def build_shared_prefix_attention_spec( "group_ids and parent_ids must be rank-2 packed tensors, got " f"{group_ids.ndim}" ) - if int(group_ids.shape[0]) != 1: - raise RuntimeError( - "ART shared-prefix attention spec currently supports exactly one packed sequence, " - f"got batch={int(group_ids.shape[0])}." 
- ) - rows: list[PackedRowAttentionSpec] = [] - for row_index in range(group_ids.shape[0]): - group_row = group_ids[row_index] - parent_row = parent_ids[row_index] - valid_tokens = _valid_length( - group_row, - parent_row, - ignore_padding_group_id=config.ignore_padding_group_id, - ) - if valid_tokens == 0: + for row in parse_shared_prefix_tree( + group_ids=group_ids, + parent_ids=parent_ids, + ignore_padding_group_id=config.ignore_padding_group_id, + require_contiguous_group_runs=config.require_contiguous_group_runs, + ): + if row.valid_tokens == 0: rows.append( - PackedRowAttentionSpec(row_index=row_index, valid_tokens=0, slices=()) + PackedRowAttentionSpec( + row_index=row.row_index, valid_tokens=0, slices=() + ) ) continue - group_row = group_row[:valid_tokens] - parent_row = parent_row[:valid_tokens] - runs = _scan_runs(group_row, parent_row) - - group_run_count: dict[int, int] = {} - prompt_by_group_id: dict[int, tuple[tuple[int, int], int]] = {} - completion_ranges_by_prompt: dict[int, list[tuple[int, int]]] = {} - - for start, end, group_id, parent_id in runs: - group_run_count[group_id] = group_run_count.get(group_id, 0) + 1 - if _is_prompt_run( - start=start, - group_id=group_id, - parent_id=parent_id, - ignore_padding_group_id=config.ignore_padding_group_id, - ): - if group_id in prompt_by_group_id: - raise RuntimeError( - f"Prompt group_id {group_id} appears more than once in row {row_index}" - ) - family_index = len(prompt_by_group_id) - prompt_by_group_id[group_id] = ( - (start, end), - family_index, - ) - completion_ranges_by_prompt[group_id] = [] - - if config.require_contiguous_group_runs: - repeated_groups = { - group_id: count - for group_id, count in group_run_count.items() - if count > 1 and group_id != config.ignore_padding_group_id - } - if repeated_groups: - raise RuntimeError( - "Shared-prefix builder requires contiguous group runs per row, " - f"found repeats in row {row_index}: {repeated_groups}" - ) - - for start, end, group_id, 
parent_id in runs: - if _is_prompt_run( - start=start, - group_id=group_id, - parent_id=parent_id, - ignore_padding_group_id=config.ignore_padding_group_id, - ): - continue - prompt_entry = prompt_by_group_id.get(parent_id) - if prompt_entry is None: - raise RuntimeError( - "Completion run points to a missing prompt run: " - f"row={row_index}, group_id={group_id}, parent_id={parent_id}" - ) - completion_ranges_by_prompt[parent_id].append((start, end)) - + segment_by_group_id = row.segment_by_group_id() row_slices: list[AttnSlice] = [] - for prompt_group_id, ( - (prompt_start, prompt_end), - family_index, - ) in prompt_by_group_id.items(): - prompt_range = TokenRange(start=prompt_start, end=prompt_end) - row_slices.append( - AttnSlice( - q_range=prompt_range, - k_range=prompt_range, - mask_kind=AttnMaskKind.CAUSAL, - row_index=row_index, - family_index=family_index, - ) - ) - for completion_start, completion_end in completion_ranges_by_prompt[ - prompt_group_id - ]: - completion_range = TokenRange( - start=completion_start, - end=completion_end, - ) + for segment in row.segments: + q_range = TokenRange(start=segment.start, end=segment.end) + for ancestor_group_id in segment.ancestors: + ancestor = segment_by_group_id[ancestor_group_id] row_slices.append( AttnSlice( - q_range=completion_range, - k_range=prompt_range, + q_range=q_range, + k_range=TokenRange(start=ancestor.start, end=ancestor.end), mask_kind=AttnMaskKind.FULL, - row_index=row_index, - family_index=family_index, + row_index=row.row_index, + family_index=segment.family_index, ) ) - row_slices.append( - AttnSlice( - q_range=completion_range, - k_range=completion_range, - mask_kind=AttnMaskKind.CAUSAL, - row_index=row_index, - family_index=family_index, - ) + row_slices.append( + AttnSlice( + q_range=q_range, + k_range=q_range, + mask_kind=AttnMaskKind.CAUSAL, + row_index=row.row_index, + family_index=segment.family_index, ) + ) rows.append( PackedRowAttentionSpec( - row_index=row_index, - 
valid_tokens=valid_tokens, + row_index=row.row_index, + valid_tokens=row.valid_tokens, slices=_sort_and_dedupe_slices(row_slices), ) ) diff --git a/src/art/megatron/context_parallel/executor.py b/src/art/megatron/context_parallel/executor.py index e5e219e72..3cb0779da 100644 --- a/src/art/megatron/context_parallel/executor.py +++ b/src/art/megatron/context_parallel/executor.py @@ -19,7 +19,7 @@ sparse_compiled_flex_attention, ) -from .block_mask import build_block_mask +from .block_mask import build_block_mask_from_context, prepare_block_mask_context from .comm import A2AVCommunicator from .range_ops import ( range_gather_head_major, @@ -684,7 +684,14 @@ def _build_stage_block_mask( raise RuntimeError( f"Stage {stage_plan.stage_index} is missing exact mask metadata" ) - mask = build_block_mask( + block_mask_context = state.execution_cache.block_mask_context + if block_mask_context is None: + block_mask_context = prepare_block_mask_context( + group_ids=state.group_ids, + parent_ids=state.parent_ids, + ) + state.execution_cache.block_mask_context = block_mask_context + mask = build_block_mask_from_context( FlexMaskSpec( q_len=int(execution_spec.q_len), k_len=int(execution_spec.k_len), @@ -692,9 +699,9 @@ def _build_stage_block_mask( slices=stage_plan.slices, exact_mask=mask_metadata.model_dump(mode="python"), ), - group_ids=state.group_ids, - parent_ids=state.parent_ids, + context=block_mask_context, device=device, + validate=False, ) cache[cache_key] = mask return mask diff --git a/src/art/megatron/context_parallel/runtime.py b/src/art/megatron/context_parallel/runtime.py index c6eb9fddd..f8888f0fd 100644 --- a/src/art/megatron/context_parallel/runtime.py +++ b/src/art/megatron/context_parallel/runtime.py @@ -2252,9 +2252,7 @@ def prepare_megatron_context_parallel_state( ) gdn_execution_spec = parse_gdn_shared_prefix_segments( - group_ids_cpu, - parent_ids_cpu, - min_completions_per_family=0, + group_ids_cpu, parent_ids_cpu, min_completions_per_family=0 ) bundle = 
_PlanningBundle( spec=spec, diff --git a/src/art/megatron/context_parallel/types.py b/src/art/megatron/context_parallel/types.py index 5cc874d09..2bc5eb657 100644 --- a/src/art/megatron/context_parallel/types.py +++ b/src/art/megatron/context_parallel/types.py @@ -119,7 +119,7 @@ class ContextParallelConfig(BaseModel): planner_remote_stage_underfill_ms: float = 0.287151 planner_tuned_backend: str | None = "art_context_parallel" planner_tuned_hardware: str | None = "NVIDIA H200" - planner_tuned_cp_sizes: tuple[int, ...] = (2,) + planner_tuned_cp_sizes: tuple[int, ...] = (2, 4) planner_cp_overrides: tuple[PlannerCpOverride, ...] = () @@ -223,6 +223,7 @@ class DispatchedPackedTensors(ContextParallelLossInputs): class ContextParallelExecutionCache(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) + block_mask_context: Any | None = None block_masks: dict[Any, Any] = Field(default_factory=dict) range_indices: dict[Any, torch.Tensor] = Field(default_factory=dict) range_meta: dict[Any, tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]] = Field( diff --git a/src/art/megatron/gdn/__init__.py b/src/art/megatron/gdn/__init__.py index cd3a0873a..1dc629403 100644 --- a/src/art/megatron/gdn/__init__.py +++ b/src/art/megatron/gdn/__init__.py @@ -3,12 +3,10 @@ from .fla_cp import chunk_gated_delta_rule_native_cp from .gdn_shared_prefix import ( GdnPackedExecutionSpec, - GdnPackedFamilySpec, GdnPlannerConfig, GdnRankExecutionPlan, GdnSegmentBucketPlan, GdnSegmentSpec, - build_gdn_cp_segment_schedule, build_gdn_rank_execution_plan, move_gdn_rank_execution_plan_to_device, parse_gdn_shared_prefix_segments, @@ -19,12 +17,10 @@ __all__ = [ "chunk_gated_delta_rule_native_cp", "GdnPackedExecutionSpec", - "GdnPackedFamilySpec", "GdnPlannerConfig", "GdnRankExecutionPlan", "GdnSegmentSpec", "GdnSegmentBucketPlan", - "build_gdn_cp_segment_schedule", "build_gdn_rank_execution_plan", "exchange_rank_tensor_all_to_all", "move_gdn_rank_execution_plan_to_device", diff --git 
a/src/art/megatron/gdn/gdn_shared_prefix.py b/src/art/megatron/gdn/gdn_shared_prefix.py index 3fb693891..85e30fed2 100644 --- a/src/art/megatron/gdn/gdn_shared_prefix.py +++ b/src/art/megatron/gdn/gdn_shared_prefix.py @@ -7,9 +7,9 @@ import torch from art.megatron.context_parallel.layout_index import TokenLayoutIndex +from art.megatron.shared_prefix_tree import parse_shared_prefix_tree GdnSegmentKind = Literal["prefix", "completion"] -GdnSegmentDecisionKey = tuple[int, int, int] # FLA's public chunk_gated_delta_rule hard-codes 64-token WY chunks. FLA_CHUNK_SIZE = 64 _PydanticModelT = TypeVar("_PydanticModelT", bound=BaseModel) @@ -38,25 +38,6 @@ def linear_indices(self, sequence_length: int) -> tuple[int, ...]: return tuple(range(base + self.start, base + self.end)) -class GdnPackedFamilySpec(BaseModel): - """One shared-prefix family plus child completion segments.""" - - model_config = ConfigDict(frozen=True) - - row_index: int = Field(ge=0) - family_index: int = Field(ge=0) - prefix: GdnSegmentSpec - completions: tuple[GdnSegmentSpec, ...] - - @property - def completion_count(self) -> int: - return len(self.completions) - - @property - def token_count(self) -> int: - return self.prefix.length + sum(segment.length for segment in self.completions) - - class GdnPackedExecutionSpec(BaseModel): """Parsed shared-prefix GDN execution metadata for a packed batch.""" @@ -65,15 +46,17 @@ class GdnPackedExecutionSpec(BaseModel): batch_size: int = Field(ge=1) sequence_length: int = Field(ge=1) valid_lengths: tuple[int, ...] - families: tuple[GdnPackedFamilySpec, ...] + tree_segments: tuple[GdnSegmentSpec, ...] + tree_parent_indices: tuple[int, ...] + tree_depths: tuple[int, ...] 
@property def family_count(self) -> int: - return len(self.families) + return len(self.tree_segments) @property def completion_count(self) -> int: - return sum(family.completion_count for family in self.families) + return sum(1 for parent in self.tree_parent_indices if parent >= 0) @property def real_token_count(self) -> int: @@ -81,19 +64,10 @@ def real_token_count(self) -> int: @property def max_segment_length(self) -> int: - lengths = [ - segment.length - for family in self.families - for segment in (family.prefix, *family.completions) - ] - return max(lengths, default=0) + return max((segment.length for segment in self.tree_segments), default=0) def segments(self) -> tuple[GdnSegmentSpec, ...]: - return tuple( - segment - for family in self.families - for segment in (family.prefix, *family.completions) - ) + return self.tree_segments _GDN_SEGMENT_SPEC_FIELDS = frozenset( @@ -108,14 +82,6 @@ def segments(self) -> tuple[GdnSegmentSpec, ...]: "child_index", } ) -_GDN_PACKED_FAMILY_SPEC_FIELDS = frozenset( - { - "row_index", - "family_index", - "prefix", - "completions", - } -) def _trusted_pydantic_construct( @@ -146,6 +112,10 @@ class GdnSegmentBucketPlan(BaseModel): row_indices: torch.Tensor position_indices: torch.Tensor family_indices: torch.Tensor + family_indices_cpu: torch.Tensor | None = None + parent_indices: torch.Tensor | None = None + parent_indices_cpu: torch.Tensor | None = None + needs_final_state: bool = True real_token_count_static: int = Field(ge=0) output_mask: torch.Tensor | None = None @@ -158,15 +128,15 @@ def real_token_count(self) -> int: return self.real_token_count_static -class GdnParentStateTransferPlan(BaseModel): - """Prefix-state rows transferred from one CP rank to another.""" +class GdnStateExchangePlan(BaseModel): + """Sparse CP exchange for tree parent states needed by remote children.""" model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) - source_rank: int = Field(ge=0) - dest_rank: int = Field(ge=0) - 
family_indices: tuple[int, ...] - family_indices_tensor: torch.Tensor | None = None + source_family_indices: tuple[int, ...] + dest_family_indices: tuple[int, ...] + exchange: Any + reverse_exchange: Any class GdnPlannerConfig(BaseModel): @@ -179,28 +149,12 @@ class GdnPlannerConfig(BaseModel): cp_chain_min_tokens_per_rank: int = Field(default=32, ge=1) cp_chain_min_total_tokens: int = Field(default=32768, ge=1) cp_chain_min_prefix_only_tokens: int = Field(default=32768, ge=1) - local_fork_launch_penalty_tokens: int = Field(default=256, ge=0) - cp_collective_latency_tokens: int = Field(default=512, ge=0) - parent_state_exchange_penalty_tokens: int = Field(default=16384, ge=0) - layout_cross_rank_token_cost: float = Field(default=6.0, ge=0.0) + cp_tree_chain_min_total_tokens: int = Field(default=8192, ge=1) + cp_tree_chain_min_prefix_only_tokens: int = Field(default=8192, ge=1) rank_idle_token_cost: float = Field(default=1.0, ge=0.0) - empty_rank_penalty_tokens: int = Field(default=65536, ge=0) max_zero_exchange_load_imbalance: float = Field(default=1.5, ge=1.0) - local_completion_rebalance_min_imbalance: float = Field(default=1.08, ge=1.0) - cp_chain_beam_width: int = Field(default=2, ge=1) - cp_chain_beam_branch_factor: int = Field(default=4, ge=1) - cp_chain_beam_candidate_limit: int = Field(default=16, ge=1) - cp_chain_beam_max_steps: int = Field(default=4, ge=0) - cp_chain_beam_min_score_delta_tokens: float = Field(default=512.0, ge=0.0) - cp_chain_min_score_delta_ms: float = Field(default=0.25, ge=0.0) planner_local_token_ms: float = Field(default=0.00065, ge=0.0) - planner_chain_token_ms: float = Field(default=0.00055, ge=0.0) - planner_local_bucket_ms: float = Field(default=0.25, ge=0.0) - planner_chain_bucket_ms: float = Field(default=22.0, ge=0.0) - planner_local_segment_ms: float = Field(default=0.010, ge=0.0) planner_layout_cross_rank_token_ms: float = Field(default=0.00008, ge=0.0) - planner_parent_state_exchange_base_ms: float = Field(default=40.0, 
ge=0.0) - planner_parent_state_exchange_ms: float = Field(default=0.5, ge=0.0) planner_empty_rank_ms: float = Field(default=32.0, ge=0.0) @@ -218,29 +172,15 @@ class GdnRankExecutionPlan(BaseModel): real_token_mask: torch.Tensor family_count: int = Field(ge=0) completion_count: int = Field(ge=0) - local_prefix_buckets: tuple[GdnSegmentBucketPlan, ...] = () - local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - ready_local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - chain_prefix_buckets: tuple[GdnSegmentBucketPlan, ...] = () - chain_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - prefix_table_is_dense_ordered: bool attention_to_gdn: Any | None = None gdn_to_attention: Any | None = None attention_token_ranges: tuple[tuple[int, int, int], ...] = () gdn_token_ranges: tuple[tuple[int, int, int], ...] = () attention_token_count: int = Field(default=0, ge=0) gdn_token_count: int = Field(default=0, ge=0) - parent_state_exchange_family_indices: tuple[int, ...] = () - parent_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () - prefix_boundary_buckets: tuple[GdnSegmentBucketPlan, ...] = () - prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - completion_with_prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_completion_with_prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_prefix_tail_exchange: Any | None = None - remote_prefix_tail_backward_exchange: Any | None = None - remote_prefix_tail_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () + tree_segment_buckets_by_depth: tuple[tuple[GdnSegmentBucketPlan, ...], ...] = () + tree_chain_buckets_by_depth: tuple[tuple[GdnSegmentBucketPlan, ...], ...] = () + tree_state_exchanges_by_depth: tuple[GdnStateExchangePlan | None, ...] 
= () @property def attention_token_indices(self) -> tuple[int, ...]: @@ -251,58 +191,6 @@ def gdn_token_indices(self) -> tuple[int, ...]: return _tokens_from_rank_ranges(self.gdn_token_ranges) -class GdnCpSegmentSchedule(BaseModel): - """CPU-side ownership and bucket schedule for one CP GDN plan.""" - - model_config = ConfigDict(frozen=True) - - gdn_token_counts_by_rank: tuple[int, ...] - gdn_token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] = () - cross_rank_token_count: int = Field(ge=0) - chain_prefix_buckets: tuple[tuple[GdnSegmentSpec, ...], ...] - chain_completion_buckets: tuple[tuple[GdnSegmentSpec, ...], ...] - local_prefix_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...] - local_completion_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...] - parent_state_exchange_family_indices: tuple[int, ...] = () - parent_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () - - -class _GdnCpSegmentSearchDecision(BaseModel): - model_config = ConfigDict(frozen=True) - - chain_segment_keys: frozenset[GdnSegmentDecisionKey] - co_locate_local_families: bool - score: float - - -class _ExplicitBucketColumn(BaseModel): - model_config = ConfigDict(frozen=True) - - row_index: int - family_index: int - positions: tuple[int, ...] - output_mask: tuple[bool, ...] 
- - @property - def length(self) -> int: - return len(self.positions) - - -def _explicit_bucket_column( - *, - row_index: int, - family_index: int, - positions: tuple[int, ...], - output_mask: tuple[bool, ...], -) -> _ExplicitBucketColumn: - return _ExplicitBucketColumn.model_construct( - row_index=row_index, - family_index=family_index, - positions=positions, - output_mask=output_mask, - ) - - class _AttentionLayoutIndex(BaseModel): """Counting index for CP attention token ownership.""" @@ -349,7 +237,6 @@ def build_gdn_rank_execution_plan( cp_rank: int = 0, cp_size: int = 1, attention_token_layout_index: TokenLayoutIndex | None = None, - cp_segment_schedule: GdnCpSegmentSchedule | None = None, planner_config: GdnPlannerConfig | None = None, ) -> GdnRankExecutionPlan: """Build rank-local tensor metadata from a parsed shared-prefix DAG. @@ -368,67 +255,226 @@ def build_gdn_rank_execution_plan( cp_rank=cp_rank, cp_size=cp_size, attention_token_layout_index=attention_token_layout_index, - cp_segment_schedule=cp_segment_schedule, planner_config=planner_config, ) return move_gdn_rank_execution_plan_to_device(cpu_plan, target_device) - if cp_size != 1 or cp_rank != 0: - return _build_cp_rank_execution_plan( - spec, - device=device, - cp_rank=cp_rank, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - cp_segment_schedule=cp_segment_schedule, - planner_config=planner_config, - ) - ( - prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_cp1_bucket_plans( + return _build_tree_rank_execution_plan( spec, device=device, + cp_rank=cp_rank, + cp_size=cp_size, + attention_token_layout_index=attention_token_layout_index, + planner_config=planner_config, + ) + + +def _build_tree_rank_execution_plan( + spec: GdnPackedExecutionSpec, + *, + device: torch.device | str, + cp_rank: int, + cp_size: int, + attention_token_layout_index: TokenLayoutIndex | None, + planner_config: GdnPlannerConfig, +) 
-> GdnRankExecutionPlan: + if cp_size < 1: + raise ValueError(f"cp_size must be >= 1, got {cp_size}") + if cp_rank < 0 or cp_rank >= cp_size: + raise ValueError(f"cp_rank must be in [0, {cp_size}), got {cp_rank}") + if not spec.tree_segments: + raise ValueError("tree GDN planning requires tree segments") + if len(spec.tree_parent_indices) != len(spec.tree_segments): + raise ValueError("tree parent metadata length must match tree segments") + if len(spec.tree_depths) != len(spec.tree_segments): + raise ValueError("tree depth metadata length must match tree segments") + + from art.megatron.gdn.layout import ( + _reverse_exchange_plan, + build_local_rank_cp_exchange_plan_from_dest_ranges, + ) + + source_layout = _attention_source_layout( + spec, + cp_size=cp_size, + attention_token_layout_index=attention_token_layout_index, planner_config=planner_config, ) - valid_lengths = torch.tensor( - spec.valid_lengths, + attention_layout_index = _build_attention_layout_index_from_token_layout( + source_layout, + max_ranges=max(1, 2 * spec.real_token_count // len(spec.tree_segments)), + ) + segment_attention_counts = _segment_attention_rank_counts( + spec, + cp_size=cp_size, + attention_layout_index=attention_layout_index, + ) + + depth_count = max(spec.tree_depths, default=0) + 1 + rank_loads = [0] * cp_size + owner_by_node = [-1] * len(spec.tree_segments) + chained_nodes = [False] * len(spec.tree_segments) + tree_has_children = [False] * len(spec.tree_segments) + for parent_index in spec.tree_parent_indices: + if parent_index >= 0: + tree_has_children[parent_index] = True + gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] + segments_by_rank_depth: list[list[list[GdnSegmentSpec]]] = [ + [[] for _ in range(depth_count)] for _ in range(cp_size) + ] + chain_segments_by_depth: list[list[GdnSegmentSpec]] = [ + [] for _ in range(depth_count) + ] + cross_rank_token_count = 0 + + tree_segments_by_depth: list[list[GdnSegmentSpec]] = [ + [] for _ in 
range(depth_count) + ] + for segment in spec.tree_segments: + tree_segments_by_depth[spec.tree_depths[segment.family_index]].append(segment) + + for depth, depth_segments in enumerate(tree_segments_by_depth): + local_groups: list[tuple[GdnSegmentSpec, ...]] = [] + for segment in depth_segments: + parent_index = spec.tree_parent_indices[segment.family_index] + if ( + parent_index < 0 + and cp_size > 1 + and _can_chain_tree_segment( + segment, + cp_size=cp_size, + planner_config=planner_config, + ) + ): + chained_nodes[segment.family_index] = True + chain_segments_by_depth[depth].append(segment) + cross_rank_token_count += _append_chain_segment( + gdn_ranges_by_rank, + rank_loads, + segment, + spec, + attention_layout_index=attention_layout_index, + ) + continue + local_groups.append((segment,)) + + for local_group in local_groups: + owner = _best_segment_owner( + local_group, + rank_loads, + segment_attention_counts=segment_attention_counts, + planner_config=planner_config, + ) + for segment in local_group: + owner_by_node[segment.family_index] = owner + segments_by_rank_depth[owner][depth].append(segment) + cross_rank_token_count += _append_local_segment( + gdn_ranges_by_rank, + rank_loads, + owner, + segment, + spec, + segment_attention_counts=segment_attention_counts, + ) + + gdn_ranges_by_rank_by_position = tuple( + tuple(ranges) for ranges in gdn_ranges_by_rank + ) + gdn_ranges_by_rank_by_source = tuple( + tuple(sorted(ranges)) for ranges in gdn_ranges_by_rank + ) + + attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( + source_layout=source_layout, device=device, - dtype=torch.long, + local_rank=cp_rank, + dest_ranges_by_rank=gdn_ranges_by_rank_by_position, + cross_rank_token_count=cross_rank_token_count, + ) + local_token_ranges = gdn_ranges_by_rank_by_source[cp_rank] + tree_segment_buckets_by_depth = tuple( + ( + _build_tree_segment_bucket_plans( + tuple(segments_by_rank_depth[cp_rank][depth]), + spec.tree_parent_indices, + 
tuple(tree_has_children), + device=device, + planner_config=planner_config, + ) + if cp_size == 1 + else _build_tree_position_bucket_plans( + tuple(segments_by_rank_depth[cp_rank][depth]), + spec.tree_parent_indices, + tuple(tree_has_children), + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + planner_config=planner_config, + ) + ) + for depth in range(depth_count) + ) + tree_chain_buckets_by_depth = ( + tuple( + _build_tree_position_bucket_plans( + tuple(chain_segments_by_depth[depth]), + spec.tree_parent_indices, + tuple(tree_has_children), + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + planner_config=planner_config, + token_ranges_by_rank=tuple( + tuple(ranges) for ranges in gdn_ranges_by_rank_by_source + ), + split_by_final_state=False, + ) + for depth in range(depth_count) + ) + if cp_size > 1 + else tuple(() for _ in range(depth_count)) ) - positions = torch.arange(spec.sequence_length, device=device, dtype=torch.long) - local_range_list: list[tuple[int, int, int]] = [] - local_position = 0 - for row_index, length in enumerate(spec.valid_lengths): - if length: - start = row_index * spec.sequence_length - local_range_list.append((start, start + length, local_position)) - local_position += length - local_ranges = tuple(local_range_list) + tree_state_exchanges_by_depth = _build_tree_state_exchanges_by_depth( + spec, + owner_by_node=tuple(owner_by_node), + chained_nodes=tuple(chained_nodes), + cp_rank=cp_rank, + cp_size=cp_size, + depth_count=depth_count, + device=device, + ) + if cp_size == 1: + valid_lengths = torch.tensor( + spec.valid_lengths, device=device, dtype=torch.long + ) + positions = torch.arange(spec.sequence_length, device=device, dtype=torch.long) + real_token_mask = positions.unsqueeze(0) < valid_lengths.unsqueeze(1) + else: + real_token_mask = torch.ones( + 1, + rank_loads[cp_rank], + device=device, + dtype=torch.bool, + ) + return GdnRankExecutionPlan.model_construct( 
cp_rank=cp_rank, cp_size=cp_size, - batch_size=spec.batch_size, - sequence_length=spec.sequence_length, + batch_size=1 if cp_size > 1 else spec.batch_size, + sequence_length=rank_loads[cp_rank] if cp_size > 1 else spec.sequence_length, packed_batch_size=spec.batch_size, packed_sequence_length=spec.sequence_length, - real_token_mask=positions.unsqueeze(0) < valid_lengths.unsqueeze(1), + real_token_mask=real_token_mask, family_count=spec.family_count, completion_count=spec.completion_count, - local_prefix_buckets=(), - local_completion_buckets=(), - ready_local_completion_buckets=(), - remote_local_completion_buckets=(), - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=False, - attention_token_ranges=local_ranges, - gdn_token_ranges=local_ranges, - attention_token_count=spec.real_token_count, - gdn_token_count=spec.real_token_count, - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, + attention_to_gdn=attention_to_gdn, + gdn_to_attention=_reverse_exchange_plan(attention_to_gdn), + attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], + gdn_token_ranges=gdn_ranges_by_rank_by_position[cp_rank], + attention_token_count=source_layout.token_counts_by_rank[cp_rank], + gdn_token_count=rank_loads[cp_rank], + tree_segment_buckets_by_depth=tree_segment_buckets_by_depth, + tree_chain_buckets_by_depth=tree_chain_buckets_by_depth, + tree_state_exchanges_by_depth=tree_state_exchanges_by_depth, ) @@ -450,52 +496,41 @@ def move_gdn_rank_execution_plan_to_device( real_token_mask=_move_planner_tensor(plan.real_token_mask, device), family_count=plan.family_count, completion_count=plan.completion_count, - local_prefix_buckets=_move_bucket_plans(plan.local_prefix_buckets, device), - local_completion_buckets=_move_bucket_plans( - plan.local_completion_buckets, device - ), - 
ready_local_completion_buckets=_move_bucket_plans( - plan.ready_local_completion_buckets, device - ), - remote_local_completion_buckets=_move_bucket_plans( - plan.remote_local_completion_buckets, device - ), - chain_prefix_buckets=_move_bucket_plans(plan.chain_prefix_buckets, device), - chain_completion_buckets=_move_bucket_plans( - plan.chain_completion_buckets, device - ), - prefix_table_is_dense_ordered=plan.prefix_table_is_dense_ordered, attention_to_gdn=move_cp_exchange_plan_to_device(plan.attention_to_gdn, device), gdn_to_attention=move_cp_exchange_plan_to_device(plan.gdn_to_attention, device), attention_token_ranges=plan.attention_token_ranges, gdn_token_ranges=plan.gdn_token_ranges, attention_token_count=plan.attention_token_count, gdn_token_count=plan.gdn_token_count, - parent_state_exchange_family_indices=plan.parent_state_exchange_family_indices, - parent_state_transfers=_move_parent_state_transfers( - plan.parent_state_transfers, device - ), - prefix_boundary_buckets=_move_bucket_plans( - plan.prefix_boundary_buckets, device - ), - prefix_tail_buckets=_move_bucket_plans(plan.prefix_tail_buckets, device), - completion_with_prefix_tail_buckets=_move_bucket_plans( - plan.completion_with_prefix_tail_buckets, device - ), - remote_prefix_tail_buckets=_move_bucket_plans( - plan.remote_prefix_tail_buckets, device + tree_segment_buckets_by_depth=tuple( + _move_bucket_plans(buckets, device) + for buckets in plan.tree_segment_buckets_by_depth ), - remote_completion_with_prefix_tail_buckets=_move_bucket_plans( - plan.remote_completion_with_prefix_tail_buckets, device + tree_chain_buckets_by_depth=tuple( + _move_bucket_plans(buckets, device) + for buckets in plan.tree_chain_buckets_by_depth ), - remote_prefix_tail_exchange=move_cp_exchange_plan_to_device( - plan.remote_prefix_tail_exchange, device + tree_state_exchanges_by_depth=tuple( + _move_state_exchange_plan(exchange, device) + for exchange in plan.tree_state_exchanges_by_depth ), - 
remote_prefix_tail_backward_exchange=move_cp_exchange_plan_to_device( - plan.remote_prefix_tail_backward_exchange, device - ), - remote_prefix_tail_state_transfers=_move_parent_state_transfers( - plan.remote_prefix_tail_state_transfers, device + ) + + +def _move_state_exchange_plan( + exchange: GdnStateExchangePlan | None, + device: torch.device | str, +) -> GdnStateExchangePlan | None: + if exchange is None: + return None + from art.megatron.gdn.layout import move_cp_exchange_plan_to_device + + return GdnStateExchangePlan.model_construct( + source_family_indices=exchange.source_family_indices, + dest_family_indices=exchange.dest_family_indices, + exchange=move_cp_exchange_plan_to_device(exchange.exchange, device), + reverse_exchange=move_cp_exchange_plan_to_device( + exchange.reverse_exchange, device ), ) @@ -516,6 +551,14 @@ def _move_bucket_plans( row_indices=_move_planner_tensor(bucket.row_indices, device), position_indices=_move_planner_tensor(bucket.position_indices, device), family_indices=_move_planner_tensor(bucket.family_indices, device), + family_indices_cpu=bucket.family_indices_cpu, + parent_indices=( + _move_planner_tensor(bucket.parent_indices, device) + if bucket.parent_indices is not None + else None + ), + parent_indices_cpu=bucket.parent_indices_cpu, + needs_final_state=bucket.needs_final_state, real_token_count_static=bucket.real_token_count, output_mask=( _move_planner_tensor(bucket.output_mask, device) @@ -527,2644 +570,111 @@ def _move_bucket_plans( ) -def _move_parent_state_transfers( - transfers: tuple[GdnParentStateTransferPlan, ...], +def parse_gdn_shared_prefix_segments( + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + *, + min_completions_per_family: int = 0, +) -> GdnPackedExecutionSpec: + """Parse ART packed shared-prefix metadata into generic GDN tree nodes.""" + + del min_completions_per_family + groups = _rank2_long_cpu("group_ids", group_ids) + parents = _rank2_long_cpu("parent_ids", parent_ids) + if tuple(groups.shape) != 
tuple(parents.shape): + raise ValueError( + "group_ids and parent_ids must have the same shape, got " + f"{tuple(groups.shape)} and {tuple(parents.shape)}" + ) + + batch_size, sequence_length = (int(groups.shape[0]), int(groups.shape[1])) + rows = parse_shared_prefix_tree(group_ids=groups, parent_ids=parents) + tree_segments: list[GdnSegmentSpec] = [] + tree_parent_indices: list[int] = [] + tree_depths: list[int] = [] + valid_lengths: list[int] = [] + node_by_row_group: dict[tuple[int, int], int] = {} + child_counts_by_parent: dict[int, int] = {} + + for row in rows: + valid_lengths.append(row.valid_tokens) + for segment in row.segments: + node_index = len(tree_segments) + is_root = segment.depth == 0 + parent_node_index = ( + -1 + if is_root + else node_by_row_group[(segment.row_index, segment.parent_id)] + ) + child_index = None + if not is_root: + child_index = child_counts_by_parent.get(parent_node_index, 0) + child_counts_by_parent[parent_node_index] = child_index + 1 + tree_segments.append( + _trusted_pydantic_construct( + GdnSegmentSpec, + _GDN_SEGMENT_SPEC_FIELDS, + row_index=segment.row_index, + family_index=node_index, + group_id=segment.group_id, + parent_id=segment.parent_id, + start=segment.start, + end=segment.end, + kind="prefix" if is_root else "completion", + child_index=child_index, + ) + ) + tree_parent_indices.append(parent_node_index) + tree_depths.append(segment.depth) + node_by_row_group[(segment.row_index, segment.group_id)] = node_index + + return GdnPackedExecutionSpec( + batch_size=batch_size, + sequence_length=sequence_length, + valid_lengths=tuple(valid_lengths), + tree_segments=tuple(tree_segments), + tree_parent_indices=tuple(tree_parent_indices), + tree_depths=tuple(tree_depths), + ) + + +def _build_segment_bucket_plans( + segment_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], + *, device: torch.device | str, -) -> tuple[GdnParentStateTransferPlan, ...]: +) -> tuple[GdnSegmentBucketPlan, ...]: return tuple( - 
GdnParentStateTransferPlan.model_construct( - source_rank=transfer.source_rank, - dest_rank=transfer.dest_rank, - family_indices=transfer.family_indices, - family_indices_tensor=( - _move_planner_tensor(transfer.family_indices_tensor, device) - if transfer.family_indices_tensor is not None - else None - ), - ) - for transfer in transfers + _build_segment_bucket_plan(bucket[0].length, bucket, device=device) + for bucket in segment_buckets ) -def _build_local_attention_layout_rank_execution_plan( +def _attention_source_layout( spec: GdnPackedExecutionSpec, *, - device: torch.device | str, - cp_rank: int, cp_size: int, attention_token_layout_index: TokenLayoutIndex | None, planner_config: GdnPlannerConfig, -) -> GdnRankExecutionPlan | None: - if cp_size <= 1 or not spec.families: - return None - if any( - _has_chainable_segment(family, cp_size=cp_size, planner_config=planner_config) - for family in spec.families - ): - return None - - from art.megatron.gdn.layout import ( - _reverse_exchange_plan, - build_local_rank_cp_exchange_plan_from_dest_ranges, - ) - - source_layout = _attention_source_layout( - spec, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - planner_config=planner_config, - ) - attention_layout_index = _build_attention_layout_index_from_token_layout( - source_layout, - max_ranges=max(1, 2 * spec.real_token_count // len(tuple(spec.segments()))), - ) - segment_attention_counts = _segment_attention_rank_counts( - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - ) - best = _assign_local_attention_segments( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - co_locate_local_families=False, - planner_config=planner_config, - ) - co_located = _assign_local_attention_segments( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - co_locate_local_families=True, - planner_config=planner_config, - ) - if co_located[4] < best[4]: - best = co_located 
- ( - prefix_owner_by_family, - completion_owners_by_family, - _, - cross_rank_token_count, - _, - ) = best - - local_prefix_segments: list[GdnSegmentSpec] = [] - local_completion_segments: list[GdnSegmentSpec] = [] - prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [[] for _ in range(cp_size)] - completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] - rank_loads = [0] * cp_size - parent_state_exchange_families: set[int] = set() - parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} - - def append_segment(rank: int, segment: GdnSegmentSpec) -> None: - token_start = _segment_token_start(segment, spec.sequence_length) - position_start = rank_loads[rank] - gdn_ranges_by_rank[rank].append( - (token_start, token_start + segment.length, position_start) - ) - rank_loads[rank] += segment.length - - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - if prefix_owner == cp_rank: - local_prefix_segments.append(family.prefix) - prefix_segments_by_rank[prefix_owner].append(family.prefix) - append_segment(prefix_owner, family.prefix) - completion_owners = completion_owners_by_family[family.family_index] - for completion, completion_owner in zip( - family.completions, completion_owners, strict=True - ): - if completion_owner == cp_rank: - local_completion_segments.append(completion) - completion_segments_by_rank[completion_owner].append(completion) - append_segment(completion_owner, completion) - if completion_owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - parent_state_transfer_families.setdefault( - (prefix_owner, completion_owner), set() - ).add(family.family_index) - - local_token_ranges = tuple(gdn_ranges_by_rank[cp_rank]) - local_token_count = rank_loads[cp_rank] - schedule = GdnCpSegmentSchedule.model_construct( - gdn_token_counts_by_rank=tuple(rank_loads), 
- gdn_token_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), - cross_rank_token_count=cross_rank_token_count, - chain_prefix_buckets=(), - chain_completion_buckets=(), - local_prefix_segments_by_rank=tuple( - tuple(segments) for segments in prefix_segments_by_rank - ), - local_completion_segments_by_rank=tuple( - tuple(segments) for segments in completion_segments_by_rank - ), - parent_state_exchange_family_indices=tuple( - sorted(parent_state_exchange_families) - ), - parent_state_transfers=_build_parent_state_transfer_plans( - parent_state_transfer_families - ), - ) - if parent_state_transfer_families: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _build_remote_prefix_tail_plans( - spec, - schedule, - cp_rank=cp_rank, - device=device, - planner_config=planner_config, - ) - else: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _empty_remote_prefix_tail_plans() - attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( - source_layout=source_layout, - device=device, - dest_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), - local_rank=cp_rank, - cross_rank_token_count=cross_rank_token_count, - ) - gdn_to_attention = _reverse_exchange_plan(attention_to_gdn) - local_prefix_family_indices = { - segment.family_index for segment in local_prefix_segments - } - local_prefix_buckets = _batch_segments_by_padded_work( - (), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - chunk_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index in 
local_prefix_family_indices - ) - plain_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index not in local_prefix_family_indices - and segment.family_index not in remote_prefix_tail_families - ) - ready_completion_segments, remote_completion_segments = ( - _split_ready_and_remote_completion_segments( - plain_local_completion_segments, - local_prefix_segments=(), - chain_prefix_buckets=(), - ) - ) - ready_completion_buckets = _batch_segments_by_padded_work( - ready_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - remote_completion_buckets = _batch_segments_by_padded_work( - remote_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - prefix_family_order = tuple( - segment.family_index for bucket in local_prefix_buckets for segment in bucket - ) - ready_completion_bucket_plans = _build_position_bucket_plans( - ready_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - remote_completion_bucket_plans = _build_position_bucket_plans( - remote_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - ( - prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_position_bucket_plans( - tuple(local_prefix_segments), - chunk_local_completion_segments, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - planner_config=planner_config, - ) - return GdnRankExecutionPlan.model_construct( - cp_rank=cp_rank, - cp_size=cp_size, - batch_size=1, - sequence_length=local_token_count, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones( - 1, local_token_count, device=device, dtype=torch.bool - ), - 
family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=_build_position_bucket_plans( - local_prefix_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - local_completion_buckets=( - ready_completion_bucket_plans + remote_completion_bucket_plans - ), - ready_local_completion_buckets=ready_completion_bucket_plans, - remote_local_completion_buckets=remote_completion_bucket_plans, - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=( - not local_prefix_segments - and prefix_family_order == tuple(range(spec.family_count)) - ), - attention_to_gdn=attention_to_gdn, - gdn_to_attention=gdn_to_attention, - attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], - gdn_token_ranges=local_token_ranges, - attention_token_count=source_layout.token_counts_by_rank[cp_rank], - gdn_token_count=local_token_count, - parent_state_exchange_family_indices=tuple( - sorted(parent_state_exchange_families - remote_prefix_tail_families) - ), - parent_state_transfers=_filter_parent_state_transfers( - _build_parent_state_transfer_plans(parent_state_transfer_families), - excluded_families=remote_prefix_tail_families, - device=device, - ), - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - remote_prefix_tail_buckets=remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets=remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange=remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange=remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers=remote_prefix_tail_state_transfers, - ) - - -def _assign_local_attention_segments( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - co_locate_local_families: bool, - 
planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[int, ...], - tuple[tuple[int, ...], ...], - tuple[int, ...], - int, - float, -]: - rank_loads = [0] * cp_size - has_prefix = [False] * cp_size - has_completion = [False] * cp_size - prefix_owner_by_family: list[int] = [] - completion_owners_by_family: list[tuple[int, ...]] = [] - parent_state_exchange_families: set[int] = set() - cross_rank_token_count = 0 - - def append_owner(rank: int, segment: GdnSegmentSpec) -> None: - nonlocal cross_rank_token_count - rank_loads[rank] += segment.length - cross_rank_token_count += ( - segment.length - segment_attention_counts[_segment_key(segment)][rank] - ) - - for family in spec.families: - if co_locate_local_families: - owner = _best_segment_owner( - (family.prefix, *family.completions), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - prefix_owner_by_family.append(owner) - completion_owners = tuple(owner for _ in family.completions) - completion_owners_by_family.append(completion_owners) - has_prefix[owner] = True - for segment in (family.prefix, *family.completions): - append_owner(owner, segment) - if family.completions: - has_completion[owner] = True - continue - - prefix_owner = _best_segment_owner( - (family.prefix,), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - prefix_owner_by_family.append(prefix_owner) - has_prefix[prefix_owner] = True - append_owner(prefix_owner, family.prefix) - completion_owners = [] - for completion in family.completions: - owner = _best_segment_owner( - (completion,), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - completion_owners.append(owner) - has_completion[owner] = True - append_owner(owner, completion) - if owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - completion_owners_by_family.append(tuple(completion_owners)) - - del 
has_prefix, has_completion - score = _score_local_segment_assignment( - spec, - cp_size=cp_size, - prefix_owner_by_family=tuple(prefix_owner_by_family), - completion_owners_by_family=tuple(completion_owners_by_family), - rank_loads=tuple(rank_loads), - cross_rank_token_count=cross_rank_token_count, - parent_state_exchange_family_count=len(parent_state_exchange_families), - planner_config=planner_config, - ) - return ( - tuple(prefix_owner_by_family), - tuple(completion_owners_by_family), - tuple(sorted(parent_state_exchange_families)), - cross_rank_token_count, - score, - ) - - -def _score_local_segment_assignment( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - prefix_owner_by_family: tuple[int, ...], - completion_owners_by_family: tuple[tuple[int, ...], ...], - rank_loads: tuple[int, ...], - cross_rank_token_count: int, - parent_state_exchange_family_count: int, - planner_config: GdnPlannerConfig, -) -> float: - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - local_completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - local_prefix_segments_by_rank[prefix_owner].append(family.prefix) - completion_owners = completion_owners_by_family[family.family_index] - for completion, completion_owner in zip( - family.completions, completion_owners, strict=True - ): - local_completion_segments_by_rank[completion_owner].append(completion) - ( - local_work_by_rank, - local_bucket_count, - local_segment_count, - ) = _estimate_local_rank_kernel_work( - tuple(tuple(segments) for segments in local_prefix_segments_by_rank), - tuple(tuple(segments) for segments in local_completion_segments_by_rank), - planner_config=planner_config, - ) - return _score_cp_segment_stats( - rank_local_work=local_work_by_rank, - rank_chain_work=tuple(0 for _ in range(cp_size)), - rank_real_tokens=rank_loads, - 
cross_rank_token_count=cross_rank_token_count, - parent_state_exchange_family_count=parent_state_exchange_family_count, - local_bucket_count=local_bucket_count, - local_segment_count=local_segment_count, - chain_bucket_count=0, - planner_config=planner_config, - ) - - -def _can_zero_exchange_colocate_families( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], -) -> bool: - for family in spec.families: - family_rank_counts = [0] * cp_size - for segment in (family.prefix, *family.completions): - segment_counts = segment_attention_counts[_segment_key(segment)] - for rank in range(cp_size): - family_rank_counts[rank] += segment_counts[rank] - if max(family_rank_counts, default=0) != family.token_count: - return False - return True - - -def parse_gdn_shared_prefix_segments( - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - *, - min_completions_per_family: int = 0, -) -> GdnPackedExecutionSpec: - """Parse ART packed shared-prefix metadata into a GDN segment DAG. - - The parser is intentionally strict: GDN state routing depends on prompt-family - boundaries, so malformed metadata should fail before execution can silently - leak recurrent or conv state across siblings or independent families. 
- """ - - groups = _rank2_long_cpu("group_ids", group_ids) - parents = _rank2_long_cpu("parent_ids", parent_ids) - if tuple(groups.shape) != tuple(parents.shape): - raise ValueError( - "group_ids and parent_ids must have the same shape, got " - f"{tuple(groups.shape)} and {tuple(parents.shape)}" - ) - - batch_size, sequence_length = (int(groups.shape[0]), int(groups.shape[1])) - valid_lengths: list[int] = [] - families: list[GdnPackedFamilySpec] = [] - for row_index in range(batch_size): - row_group_ids = groups[row_index] - row_parent_ids = parents[row_index] - valid_length = _validate_padding_tensor( - row_index, row_group_ids, row_parent_ids - ) - valid_lengths.append(valid_length) - if valid_length == 0: - continue - families.extend( - _parse_row_tensor( - row_index=row_index, - group_ids=row_group_ids, - parent_ids=row_parent_ids, - valid_length=valid_length, - first_family_index=len(families), - min_completions_per_family=min_completions_per_family, - ) - ) - - return GdnPackedExecutionSpec( - batch_size=batch_size, - sequence_length=sequence_length, - valid_lengths=tuple(valid_lengths), - families=tuple(families), - ) - - -def _build_segment_bucket_plans( - segment_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], - *, - device: torch.device | str, -) -> tuple[GdnSegmentBucketPlan, ...]: - return tuple( - _build_segment_bucket_plan(bucket[0].length, bucket, device=device) - for bucket in segment_buckets - ) - - -def _build_chunk_aligned_cp1_bucket_plans( - spec: GdnPackedExecutionSpec, - *, - device: torch.device | str, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], -]: - boundary_segments: list[GdnSegmentSpec] = [] - tail_segments: list[GdnSegmentSpec] = [] - completion_columns: list[_ExplicitBucketColumn] = [] - for family in spec.families: - prefix = family.prefix - boundary_end = _prefix_chunk_boundary_end(prefix) - if boundary_end > prefix.start: 
- boundary_segments.append( - _segment_with_bounds(prefix, prefix.start, boundary_end) - ) - prefix_tail_positions = tuple(range(boundary_end, prefix.end)) - if prefix_tail_positions and not family.completions: - tail_segments.append(_segment_with_bounds(prefix, boundary_end, prefix.end)) - for child_offset, completion in enumerate(family.completions): - completion_positions = prefix_tail_positions + tuple( - range(completion.start, completion.end) - ) - completion_columns.append( - _explicit_bucket_column( - row_index=completion.row_index, - family_index=completion.family_index, - positions=completion_positions, - output_mask=( - ((child_offset == 0),) * len(prefix_tail_positions) - + (True,) * completion.length - ), - ) - ) - boundary_buckets = _batch_segments_by_padded_work( - tuple(boundary_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - tail_buckets = _batch_segments_by_padded_work( - tuple(tail_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_column_batches = _batch_explicit_bucket_columns( - tuple(completion_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - _build_segment_bucket_plans(boundary_buckets, device=device), - _build_segment_bucket_plans(tail_buckets, device=device), - _build_explicit_bucket_plans(completion_column_batches, device=device), - ) - - -def _build_chunk_aligned_position_bucket_plans( - prefix_segments: tuple[GdnSegmentSpec, ...], - completion_segments: tuple[GdnSegmentSpec, ...], - local_token_ranges: tuple[tuple[int, int, int], ...], - *, - sequence_length: int, - device: torch.device | str, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], -]: - 
local_range_ends = tuple(token_end for _, token_end, _ in local_token_ranges) - local_range_positions = { - (token_start, token_end): position_start - for token_start, token_end, position_start in local_token_ranges - } - completions_by_family: dict[int, list[GdnSegmentSpec]] = {} - for completion in completion_segments: - completions_by_family.setdefault(completion.family_index, []).append(completion) - boundary_segments: list[GdnSegmentSpec] = [] - tail_segments: list[GdnSegmentSpec] = [] - completion_columns: list[_ExplicitBucketColumn] = [] - for prefix in prefix_segments: - boundary_end = _prefix_chunk_boundary_end(prefix) - if boundary_end > prefix.start: - boundary_segments.append( - _segment_with_bounds(prefix, prefix.start, boundary_end) - ) - family_completions = tuple(completions_by_family.get(prefix.family_index, ())) - prefix_tail_positions = _local_positions_for_span( - prefix.row_index, - boundary_end, - prefix.end, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - local_range_positions=local_range_positions, - ) - if prefix_tail_positions and not family_completions: - tail_segments.append(_segment_with_bounds(prefix, boundary_end, prefix.end)) - for child_offset, completion in enumerate(family_completions): - completion_positions = _local_positions_for_span( - completion.row_index, - completion.start, - completion.end, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - local_range_positions=local_range_positions, - ) - positions = prefix_tail_positions + completion_positions - completion_columns.append( - _explicit_bucket_column( - row_index=0, - family_index=completion.family_index, - positions=positions, - output_mask=( - ((child_offset == 0),) * len(prefix_tail_positions) - + (True,) * len(completion_positions) - ), - ) - ) - boundary_buckets = _batch_segments_by_padded_work( - tuple(boundary_segments), - 
max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - tail_buckets = _batch_segments_by_padded_work( - tuple(tail_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_column_batches = _batch_explicit_bucket_columns( - tuple(completion_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - _build_position_bucket_plans( - boundary_buckets, - local_token_ranges, - sequence_length=sequence_length, - device=device, - ), - _build_position_bucket_plans( - tail_buckets, - local_token_ranges, - sequence_length=sequence_length, - device=device, - ), - _build_explicit_bucket_plans(completion_column_batches, device=device), - ) - - -def _build_remote_prefix_tail_plans( - spec: GdnPackedExecutionSpec, - schedule: GdnCpSegmentSchedule, - *, - cp_rank: int, - device: torch.device | str, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - Any | None, - Any | None, - tuple[GdnParentStateTransferPlan, ...], - frozenset[int], -]: - from art.megatron.gdn.layout import ( - GdnCpExchangePlan, - GdnCpPeerTransfer, - _reverse_exchange_plan, - ) - - family_by_index = {family.family_index: family for family in spec.families} - prefix_owner_by_family = _prefix_owner_by_family(schedule) - source_positions_by_pair: dict[tuple[int, int], list[int]] = {} - dest_positions_by_pair: dict[tuple[int, int], list[int]] = {} - dest_counts = [0 for _ in schedule.gdn_token_counts_by_rank] - state_transfer_families: dict[tuple[int, int], set[int]] = {} - remote_tail_family_indices: set[int] = set() - local_tail_columns: list[_ExplicitBucketColumn] = [] - local_completion_columns: list[_ExplicitBucketColumn] = [] - tail_positions_by_dest_family: dict[tuple[int, int], tuple[int, ...]] = {} - 
local_tail_column_families: set[int] = set() - rank_ranges = schedule.gdn_token_ranges_by_rank - rank_range_ends = tuple( - tuple(end for _, end, _ in ranges) for ranges in rank_ranges - ) - rank_range_positions = tuple( - { - (token_start, token_end): position_start - for token_start, token_end, position_start in ranges - } - for ranges in rank_ranges - ) - - for dest_rank, completions in enumerate(schedule.local_completion_segments_by_rank): - for completion in completions: - source_rank = prefix_owner_by_family.get(completion.family_index) - if source_rank is None or source_rank == dest_rank: - continue - family = family_by_index[completion.family_index] - boundary_end = _prefix_chunk_boundary_end(family.prefix) - if boundary_end == family.prefix.end: - continue - dest_family = (dest_rank, family.family_index) - dest_positions = tail_positions_by_dest_family.get(dest_family) - if dest_positions is None: - source_positions = _local_positions_for_span( - family.prefix.row_index, - boundary_end, - family.prefix.end, - sequence_length=spec.sequence_length, - local_token_ranges=rank_ranges[source_rank], - local_range_ends=rank_range_ends[source_rank], - local_range_positions=rank_range_positions[source_rank], - ) - if len(source_positions) != family.prefix.end - boundary_end: - raise ValueError( - "remote prefix-tail exchange could not locate all source tokens " - f"for family {family.family_index}" - ) - dest_start = dest_counts[dest_rank] - dest_positions = tuple( - range(dest_start, dest_start + len(source_positions)) - ) - tail_positions_by_dest_family[dest_family] = dest_positions - dest_counts[dest_rank] += len(source_positions) - pair = (source_rank, dest_rank) - source_positions_by_pair.setdefault(pair, []).extend(source_positions) - dest_positions_by_pair.setdefault(pair, []).extend(dest_positions) - state_transfer_families.setdefault(pair, set()).add(family.family_index) - remote_tail_family_indices.add(family.family_index) - - if dest_rank != cp_rank: - 
continue - completion_positions = _local_positions_for_span( - completion.row_index, - completion.start, - completion.end, - sequence_length=spec.sequence_length, - local_token_ranges=rank_ranges[dest_rank], - local_range_ends=rank_range_ends[dest_rank], - local_range_positions=rank_range_positions[dest_rank], - ) - if len(completion_positions) != completion.length: - raise ValueError( - "remote prefix-tail bucket could not locate all completion tokens " - f"for family {family.family_index}" - ) - remote_base = int(schedule.gdn_token_counts_by_rank[dest_rank]) - if ( - len(dest_positions) > 0 - and family.family_index not in local_tail_column_families - ): - local_tail_column_families.add(family.family_index) - local_tail_columns.append( - _explicit_bucket_column( - row_index=0, - family_index=family.family_index, - positions=tuple(remote_base + pos for pos in dest_positions), - output_mask=(False,) * len(dest_positions), - ) - ) - local_completion_columns.append( - _explicit_bucket_column( - row_index=0, - family_index=family.family_index, - positions=completion_positions, - output_mask=(True,) * len(completion_positions), - ) - ) - - if not source_positions_by_pair: - return (), (), None, None, (), frozenset() - - transfers = tuple( - GdnCpPeerTransfer.model_construct( - source_rank=source_rank, - dest_rank=dest_rank, - token_count=len(source_positions), - source_positions_tensor=_move_planner_tensor( - torch.tensor(source_positions, dtype=torch.long), device - ), - dest_positions_tensor=_move_planner_tensor( - torch.tensor( - dest_positions_by_pair[(source_rank, dest_rank)], - dtype=torch.long, - ), - device, - ), - ) - for (source_rank, dest_rank), source_positions in sorted( - source_positions_by_pair.items() - ) - ) - exchange = GdnCpExchangePlan.model_construct( - cp_size=len(schedule.gdn_token_counts_by_rank), - source_token_counts_by_rank=schedule.gdn_token_counts_by_rank, - dest_token_counts_by_rank=tuple(dest_counts), - transfers=transfers, - 
cross_rank_token_count_override=sum(dest_counts), - ) - tail_column_batches = _batch_explicit_bucket_columns( - tuple(local_tail_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_column_batches = _batch_explicit_bucket_columns( - tuple(local_completion_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - _build_explicit_bucket_plans(tail_column_batches, device=device), - _build_explicit_bucket_plans(completion_column_batches, device=device), - exchange, - _reverse_exchange_plan(exchange), - _transfer_plans_to_device( - _build_parent_state_transfer_plans(state_transfer_families), - device=device, - ), - frozenset(remote_tail_family_indices), - ) - - -def _empty_remote_prefix_tail_plans() -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - Any | None, - Any | None, - tuple[GdnParentStateTransferPlan, ...], - frozenset[int], -]: - return (), (), None, None, (), frozenset() - - -def _prefix_owner_by_family(schedule: GdnCpSegmentSchedule) -> dict[int, int]: - owners: dict[int, int] = {} - for rank, segments in enumerate(schedule.local_prefix_segments_by_rank): - for segment in segments: - owners[segment.family_index] = rank - return owners - - -def _filter_parent_state_transfers( - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - excluded_families: frozenset[int], - device: torch.device | str, -) -> tuple[GdnParentStateTransferPlan, ...]: - if not excluded_families: - return _transfer_plans_to_device(transfers, device=device) - kept: dict[tuple[int, int], set[int]] = {} - for transfer in transfers: - families = set(transfer.family_indices) - excluded_families - if families: - kept.setdefault((transfer.source_rank, transfer.dest_rank), set()).update( - families - ) - return _transfer_plans_to_device( - _build_parent_state_transfer_plans(kept), 
device=device - ) - - -def _local_positions_for_span( - row_index: int, - start: int, - end: int, - *, - sequence_length: int, - local_token_ranges: tuple[tuple[int, int, int], ...], - local_range_ends: tuple[int, ...], - local_range_positions: dict[tuple[int, int], int] | None = None, -) -> tuple[int, ...]: - if start == end: - return () - token_start = row_index * sequence_length + start - token_end = row_index * sequence_length + end - if local_range_positions is not None: - position_start = local_range_positions.get((token_start, token_end)) - if position_start is not None: - return tuple(range(position_start, position_start + end - start)) - range_index = bisect_left(local_range_ends, token_start + 1) - if range_index < len(local_token_ranges): - range_start, range_end, position_start = local_token_ranges[range_index] - if range_start <= token_start and token_end <= range_end: - local_start = position_start + token_start - range_start - return tuple(range(local_start, local_start + end - start)) - segment = _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=row_index, - family_index=0, - group_id=0, - parent_id=0, - start=start, - end=end, - kind="prefix", - child_index=None, - ) - return tuple( - int(position) - for position in _local_positions_for_segment( - segment, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - ).tolist() - ) - - -def _prefix_chunk_boundary_end(prefix: GdnSegmentSpec) -> int: - aligned_length = (prefix.length // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE - return prefix.start + aligned_length - - -def _segment_with_bounds( - segment: GdnSegmentSpec, start: int, end: int -) -> GdnSegmentSpec: - return _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=segment.row_index, - family_index=segment.family_index, - group_id=segment.group_id, - parent_id=segment.parent_id, - start=start, - end=end, - kind=segment.kind, - 
child_index=segment.child_index, - ) - - -def _batch_explicit_bucket_columns( - columns: tuple[_ExplicitBucketColumn, ...], - *, - max_padding_ratio: float = 1.25, - max_segments_per_batch: int = 128, -) -> tuple[tuple[_ExplicitBucketColumn, ...], ...]: - if not columns: - return () - ordered = sorted( - columns, - key=lambda column: (column.length, column.family_index, column.row_index), - ) - batches: list[list[_ExplicitBucketColumn]] = [] - current: list[_ExplicitBucketColumn] = [] - current_tokens = 0 - current_max = 0 - for column in ordered: - next_count = len(current) + 1 - next_tokens = current_tokens + column.length - next_max = max(current_max, column.length) - padded = next_max * next_count - can_extend = not current or ( - next_count <= max_segments_per_batch - and padded <= max_padding_ratio * next_tokens - ) - if not can_extend: - batches.append(current) - current = [] - current_tokens = 0 - current_max = 0 - current.append(column) - current_tokens += column.length - current_max = max(current_max, column.length) - if current: - batches.append(current) - return tuple(tuple(batch) for batch in batches) - - -def _build_explicit_bucket_plans( - bucket_columns: tuple[tuple[_ExplicitBucketColumn, ...], ...], - *, - device: torch.device | str, -) -> tuple[GdnSegmentBucketPlan, ...]: - return tuple( - _build_explicit_bucket_plan(columns, device=device) - for columns in bucket_columns - ) - - -def _build_explicit_bucket_plan( - columns: tuple[_ExplicitBucketColumn, ...], - *, - device: torch.device | str, -) -> GdnSegmentBucketPlan: - max_length = max(column.length for column in columns) - column_count = len(columns) - lengths = [column.length for column in columns] - lengths_cpu = torch.tensor(lengths, dtype=torch.long) - offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) - real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) - padded_element_count = max_length * column_count - row_indices = [0] * padded_element_count - position_indices = 
[0] * padded_element_count - output_mask = [False] * padded_element_count - for column_index, column in enumerate(columns): - length = column.length - column_slice = slice(column_index, length * column_count, column_count) - row_indices[column_slice] = [column.row_index] * length - position_indices[column_slice] = column.positions - output_mask[column_slice] = column.output_mask - row_indices_cpu = torch.tensor(row_indices, dtype=torch.long).reshape( - max_length, column_count - ) - position_indices_cpu = torch.tensor(position_indices, dtype=torch.long).reshape( - max_length, column_count - ) - output_mask_cpu = torch.tensor(output_mask, dtype=torch.bool).reshape( - max_length, column_count - ) - family_indices_cpu = torch.tensor( - [column.family_index for column in columns], dtype=torch.long - ) - cu_seqlens_cpu = torch.cat( - [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] - ) - return GdnSegmentBucketPlan.model_construct( - length=max_length, - lengths=_move_planner_tensor(lengths_cpu, device), - lengths_cpu=lengths_cpu, - lengths_by_rank_cpu=None, - real_mask=_move_planner_tensor(real_mask_cpu, device), - cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), - cu_seqlens_cpu=cu_seqlens_cpu, - row_indices=_move_planner_tensor(row_indices_cpu, device), - position_indices=_move_planner_tensor(position_indices_cpu, device), - family_indices=_move_planner_tensor(family_indices_cpu, device), - real_token_count_static=int(lengths_cpu.sum().item()), - output_mask=_move_planner_tensor(output_mask_cpu, device), - ) - - -def _attention_source_layout( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_token_layout_index: TokenLayoutIndex | None, - planner_config: GdnPlannerConfig, -) -> TokenLayoutIndex: - if attention_token_layout_index is not None: - if _layout_cp_size(attention_token_layout_index) != cp_size: - raise ValueError( - "attention token layout index cp_size must match GDN cp_size, got " - 
f"{_layout_cp_size(attention_token_layout_index)} and {cp_size}" - ) - if _layout_token_count(attention_token_layout_index) != spec.real_token_count: - raise ValueError( - "attention token layout index token count must match GDN real token " - f"count, got {_layout_token_count(attention_token_layout_index)} and " - f"{spec.real_token_count}" - ) - return attention_token_layout_index - return _token_layout_from_rank_ranges( - _default_attention_layout_ranges( - spec, - cp_size=cp_size, - planner_config=planner_config, - ) - ) - - -def _build_cp_rank_execution_plan( - spec: GdnPackedExecutionSpec, - *, - device: torch.device | str, - cp_rank: int, - cp_size: int, - attention_token_layout_index: TokenLayoutIndex | None, - cp_segment_schedule: GdnCpSegmentSchedule | None, - planner_config: GdnPlannerConfig, -) -> GdnRankExecutionPlan: - if cp_size < 1: - raise ValueError(f"cp_size must be >= 1, got {cp_size}") - if cp_rank < 0 or cp_rank >= cp_size: - raise ValueError(f"cp_rank must be in [0, {cp_size}), got {cp_rank}") - if ( - attention_token_layout_index is not None - and _layout_cp_size(attention_token_layout_index) != cp_size - ): - raise ValueError( - "attention token layout index cp_size must match GDN cp_size, got " - f"{_layout_cp_size(attention_token_layout_index)} and {cp_size}" - ) - - from art.megatron.gdn.layout import ( - _reverse_exchange_plan, - build_local_rank_cp_exchange_plan_from_dest_ranges, - ) - - has_explicit_attention_layout = attention_token_layout_index is not None - if cp_segment_schedule is None and not has_explicit_attention_layout: - local_family_plan = _build_local_family_rank_execution_plan( - spec, - device=device, - cp_rank=cp_rank, - cp_size=cp_size, - planner_config=planner_config, - ) - if local_family_plan is not None: - return local_family_plan - if cp_segment_schedule is None and has_explicit_attention_layout: - local_layout_plan = _build_local_attention_layout_rank_execution_plan( - spec, - device=device, - cp_rank=cp_rank, - 
cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - planner_config=planner_config, - ) - if local_layout_plan is not None: - return local_layout_plan - - source_layout = _attention_source_layout( - spec, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - planner_config=planner_config, - ) - if cp_segment_schedule is None: - schedule = _build_cp_segment_schedule( - spec, - cp_size=cp_size, - attention_layout_index=_build_attention_layout_index_from_token_layout( - source_layout, - max_ranges=max( - 1, - (2 * spec.real_token_count) // max(1, len(spec.segments())), - ), - ), - planner_config=planner_config, - ) - else: - schedule = cp_segment_schedule - if len(schedule.gdn_token_counts_by_rank) != cp_size: - raise ValueError(f"CP GDN schedule must contain {cp_size} ranks") - attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( - source_layout=source_layout, - device=device, - local_rank=cp_rank, - dest_ranges_by_rank=schedule.gdn_token_ranges_by_rank, - cross_rank_token_count=schedule.cross_rank_token_count, - ) - gdn_to_attention = _reverse_exchange_plan(attention_to_gdn) - local_token_ranges = schedule.gdn_token_ranges_by_rank[cp_rank] - local_gdn_token_count = schedule.gdn_token_counts_by_rank[cp_rank] - if schedule.parent_state_exchange_family_indices: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _build_remote_prefix_tail_plans( - spec, - schedule, - cp_rank=cp_rank, - device=device, - planner_config=planner_config, - ) - else: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _empty_remote_prefix_tail_plans() - - chain_prefix_buckets = 
tuple( - bucket for bucket in schedule.chain_prefix_buckets if bucket - ) - chain_completion_buckets = tuple( - bucket for bucket in schedule.chain_completion_buckets if bucket - ) - local_prefix_segments = tuple(schedule.local_prefix_segments_by_rank[cp_rank]) - local_prefix_family_indices = { - segment.family_index for segment in local_prefix_segments - } - local_prefix_buckets = _batch_segments_by_padded_work( - () if local_prefix_segments else (), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - local_completion_segments = tuple( - schedule.local_completion_segments_by_rank[cp_rank] - ) - chunk_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index in local_prefix_family_indices - ) - plain_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index not in local_prefix_family_indices - and segment.family_index not in remote_prefix_tail_families - ) - ready_completion_segments, remote_completion_segments = ( - _split_ready_and_remote_completion_segments( - plain_local_completion_segments, - local_prefix_segments=(), - chain_prefix_buckets=chain_prefix_buckets, - ) - ) - ready_local_completion_buckets = _batch_segments_by_padded_work( - ready_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - remote_local_completion_buckets = _batch_segments_by_padded_work( - remote_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - local_completion_buckets = ( - ready_local_completion_buckets + remote_local_completion_buckets - ) - prefix_family_order = tuple( - segment.family_index - for bucket in ( - *chain_prefix_buckets, - *local_prefix_buckets, - ) - for segment in bucket - ) - ( - 
prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_position_bucket_plans( - local_prefix_segments, - chunk_local_completion_segments, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - planner_config=planner_config, - ) - return GdnRankExecutionPlan.model_construct( - cp_rank=cp_rank, - cp_size=cp_size, - batch_size=1, - sequence_length=local_gdn_token_count, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones( - 1, local_gdn_token_count, device=device, dtype=torch.bool - ), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=_build_position_bucket_plans( - local_prefix_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - local_completion_buckets=_build_position_bucket_plans( - local_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - ready_local_completion_buckets=_build_position_bucket_plans( - ready_local_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - remote_local_completion_buckets=_build_position_bucket_plans( - remote_local_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - chain_prefix_buckets=_build_position_bucket_plans( - chain_prefix_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - token_ranges_by_rank=schedule.gdn_token_ranges_by_rank, - ), - chain_completion_buckets=_build_position_bucket_plans( - chain_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - token_ranges_by_rank=schedule.gdn_token_ranges_by_rank, - ), - prefix_table_is_dense_ordered=( - not local_prefix_segments - and prefix_family_order == tuple(range(spec.family_count)) - ), - 
attention_to_gdn=attention_to_gdn, - gdn_to_attention=gdn_to_attention, - attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], - gdn_token_ranges=local_token_ranges, - attention_token_count=source_layout.token_counts_by_rank[cp_rank], - gdn_token_count=local_gdn_token_count, - parent_state_exchange_family_indices=( - tuple( - family_index - for family_index in schedule.parent_state_exchange_family_indices - if family_index not in remote_prefix_tail_families - ) - ), - parent_state_transfers=_filter_parent_state_transfers( - schedule.parent_state_transfers, - excluded_families=remote_prefix_tail_families, - device=device, - ), - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - remote_prefix_tail_buckets=remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets=remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange=remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange=remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers=remote_prefix_tail_state_transfers, - ) - - -def build_gdn_cp_segment_schedule( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_token_layout_index: TokenLayoutIndex | None = None, - planner_config: GdnPlannerConfig | None = None, -) -> GdnCpSegmentSchedule: - planner_config = planner_config or GdnPlannerConfig() - source_layout = _attention_source_layout( - spec, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - planner_config=planner_config, - ) - return _build_cp_segment_schedule( - spec, - cp_size=cp_size, - attention_layout_index=_build_attention_layout_index_from_token_layout( - source_layout, - max_ranges=max( - 1, (2 * spec.real_token_count) // max(1, len(spec.segments())) - ), - ), - planner_config=planner_config, - ) - - -def _build_cp_segment_schedule( - spec: GdnPackedExecutionSpec, - *, 
- cp_size: int, - attention_layout_index: _AttentionLayoutIndex, - planner_config: GdnPlannerConfig, -) -> GdnCpSegmentSchedule: - segment_attention_counts = _segment_attention_rank_counts( - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - ) - legal_chain_segments = tuple( - segment - for family in spec.families - for segment in (family.prefix, *family.completions) - if ( - _can_chain_prefix_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - if segment.kind == "prefix" - else _can_chain_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - ) - ) - decision = _beam_search_cp_segment_schedule_decision( - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - segment_attention_counts=segment_attention_counts, - legal_chain_segments=legal_chain_segments, - planner_config=planner_config, - ) - return _materialize_cp_segment_schedule( - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - segment_attention_counts=segment_attention_counts, - chain_segment_keys=decision.chain_segment_keys, - co_locate_local_families=decision.co_locate_local_families, - planner_config=planner_config, - ) - - -def _beam_search_cp_segment_schedule_decision( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_layout_index: _AttentionLayoutIndex, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - legal_chain_segments: tuple[GdnSegmentSpec, ...], - planner_config: GdnPlannerConfig, -) -> _GdnCpSegmentSearchDecision: - legal_chain_keys = frozenset( - _segment_key(segment) for segment in legal_chain_segments - ) - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]] = {} - chain_cross_rank_tokens_by_key: dict[GdnSegmentDecisionKey, int] = {} - for segment in legal_chain_segments: - key = _segment_key(segment) - ( - chain_rank_counts_by_key[key], - chain_cross_rank_tokens_by_key[key], - ) = 
_chain_segment_rank_counts_and_cross_rank_tokens( - segment, - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - ) - - score_cache: dict[ - frozenset[GdnSegmentDecisionKey], _GdnCpSegmentSearchDecision - ] = {} - - def decision_for( - chain_segment_keys: frozenset[GdnSegmentDecisionKey], - ) -> _GdnCpSegmentSearchDecision: - cached = score_cache.get(chain_segment_keys) - if cached is not None: - return cached - non_colocated_score = _score_cp_segment_decisions( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - chain_segment_keys=chain_segment_keys, - co_locate_local_families=False, - planner_config=planner_config, - ) - colocated_score = _score_cp_segment_decisions( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - chain_segment_keys=chain_segment_keys, - co_locate_local_families=True, - planner_config=planner_config, - ) - co_locate = colocated_score < non_colocated_score - decision = _GdnCpSegmentSearchDecision.model_construct( - chain_segment_keys=chain_segment_keys, - co_locate_local_families=co_locate, - score=colocated_score if co_locate else non_colocated_score, - ) - score_cache[chain_segment_keys] = decision - return decision - - best = decision_for(frozenset()) - beam_by_keys = {best.chain_segment_keys: best} - if legal_chain_keys: - all_chain = decision_for(legal_chain_keys) - beam_by_keys[all_chain.chain_segment_keys] = all_chain - if best.score - all_chain.score > planner_config.cp_chain_min_score_delta_ms: - best = all_chain - candidate_groups = _bounded_chain_candidate_groups( - spec, - legal_chain_segments, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - 
planner_config=planner_config, - ) - beam = _best_cp_segment_search_decisions( - beam_by_keys.values(), - limit=planner_config.cp_chain_beam_width, - ) - stale_steps = 0 - for _ in range(planner_config.cp_chain_beam_max_steps): - if not candidate_groups: - break - expanded: dict[ - frozenset[GdnSegmentDecisionKey], _GdnCpSegmentSearchDecision - ] = {} - for decision in beam: - neighbors = [] - for segment_keys in _chain_beam_neighbor_groups( - decision.chain_segment_keys, - candidate_groups=candidate_groups, - branch_factor=planner_config.cp_chain_beam_branch_factor, - ): - if segment_keys.issubset(decision.chain_segment_keys): - next_keys = decision.chain_segment_keys - segment_keys - else: - next_keys = decision.chain_segment_keys | segment_keys - neighbors.append(decision_for(frozenset(next_keys))) - for neighbor in _best_cp_segment_search_decisions( - neighbors, - limit=planner_config.cp_chain_beam_branch_factor, - ): - expanded[neighbor.chain_segment_keys] = neighbor - if not expanded: - break - beam = _best_cp_segment_search_decisions( - (*beam, *expanded.values()), - limit=planner_config.cp_chain_beam_width, - ) - step_best = beam[0] - if best.score - step_best.score > planner_config.cp_chain_min_score_delta_ms: - best = step_best - stale_steps = 0 - else: - stale_steps += 1 - if stale_steps >= 2: - break - return best - - -def _chain_beam_neighbor_groups( - chain_segment_keys: frozenset[GdnSegmentDecisionKey], - *, - candidate_groups: tuple[frozenset[GdnSegmentDecisionKey], ...], - branch_factor: int, -) -> tuple[frozenset[GdnSegmentDecisionKey], ...]: - selected: list[frozenset[GdnSegmentDecisionKey]] = [] - for group in candidate_groups: - if group and not group.issubset(chain_segment_keys): - selected.append(group) - if len(selected) >= branch_factor: - return tuple(selected) - for group in reversed(candidate_groups): - if group and group.intersection(chain_segment_keys) and group not in selected: - selected.append(group) - if len(selected) >= 
branch_factor: - break - return tuple(selected) - - -def _best_cp_segment_search_decisions( - decisions: Any, - *, - limit: int, -) -> tuple[_GdnCpSegmentSearchDecision, ...]: - return tuple( - sorted( - decisions, - key=lambda decision: ( - decision.score, - len(decision.chain_segment_keys), - tuple(sorted(decision.chain_segment_keys)), - ), - )[:limit] - ) - - -def _bounded_chain_candidate_groups( - spec: GdnPackedExecutionSpec, - legal_chain_segments: tuple[GdnSegmentSpec, ...], - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - planner_config: GdnPlannerConfig, -) -> tuple[frozenset[GdnSegmentDecisionKey], ...]: - legal_key_set = frozenset(_segment_key(segment) for segment in legal_chain_segments) - if not legal_key_set: - return () - prefix_keys = frozenset( - _segment_key(family.prefix) - for family in spec.families - if _segment_key(family.prefix) in legal_key_set - ) - completion_keys = legal_key_set - prefix_keys - groups: list[frozenset[GdnSegmentDecisionKey]] = [] - for group in (legal_key_set, prefix_keys, completion_keys): - if group and group not in groups: - groups.append(group) - for group in _ranked_chain_beam_groups( - spec, - legal_chain_segments, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - planner_config=planner_config, - ): - if group and group not in groups: - groups.append(group) - return tuple(groups[: planner_config.cp_chain_beam_candidate_limit]) - - -def _ranked_chain_beam_groups( - spec: GdnPackedExecutionSpec, - legal_chain_segments: tuple[GdnSegmentSpec, ...], - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - planner_config: GdnPlannerConfig, -) -> tuple[frozenset[GdnSegmentDecisionKey], ...]: - if not legal_chain_segments: - return () - priority_by_key = { - 
_segment_key(segment): _chain_beam_segment_priority( - segment, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - ) - for segment in legal_chain_segments - } - legal_key_set = frozenset(priority_by_key) - groups: set[frozenset[GdnSegmentDecisionKey]] = { - frozenset((key,)) for key in legal_key_set - } - for family in spec.families: - completion_keys = frozenset( - _segment_key(completion) - for completion in family.completions - if _segment_key(completion) in legal_key_set - ) - if len(completion_keys) > 1: - groups.add(completion_keys) - family_keys = completion_keys - prefix_key = _segment_key(family.prefix) - if prefix_key in legal_key_set: - family_keys = family_keys | frozenset((prefix_key,)) - if len(family_keys) > 1: - groups.add(family_keys) - ranked = tuple( - sorted( - groups, - key=lambda group: _chain_beam_group_priority( - group, priority_by_key=priority_by_key - ), - reverse=True, - ) - ) - limit = planner_config.cp_chain_beam_candidate_limit - if len(ranked) <= limit: - return ranked - high_count = (limit + 1) // 2 - low_count = limit - high_count - selected = [*ranked[:high_count]] - for group in ranked[-low_count:]: - if group not in selected: - selected.append(group) - return tuple(selected) - - -def _chain_beam_group_priority( - group: frozenset[GdnSegmentDecisionKey], - *, - priority_by_key: dict[GdnSegmentDecisionKey, tuple[int, int, int, int]], -) -> tuple[int, int, int, int, int]: - priorities = tuple(priority_by_key[key] for key in group) - return ( - sum(priority[0] for priority in priorities), - sum(priority[1] for priority in priorities), - max((priority[2] for priority in priorities), default=0), - sum(priority[3] for priority in priorities), - len(group), - ) - - -def _chain_beam_segment_priority( - segment: GdnSegmentSpec, - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], -) -> 
tuple[int, int, int, int]: - key = _segment_key(segment) - chain_max_load = max(chain_rank_counts_by_key[key], default=0) - best_attention_locality = max(segment_attention_counts[key], default=0) - chain_load_relief = segment.length - chain_max_load - minimum_local_exchange = segment.length - best_attention_locality - return ( - chain_load_relief, - segment.length, - best_attention_locality, - -minimum_local_exchange, - ) - - -def _score_cp_segment_decisions( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - chain_cross_rank_tokens_by_key: dict[GdnSegmentDecisionKey, int], - chain_segment_keys: frozenset[GdnSegmentDecisionKey], - co_locate_local_families: bool, - planner_config: GdnPlannerConfig, -) -> float: - rank_loads = [0] * cp_size - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - local_completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - chain_prefix_segments: list[GdnSegmentSpec] = [] - chain_completion_segments: list[GdnSegmentSpec] = [] - parent_state_exchange_families: set[int] = set() - cross_rank_token_count = 0 - - for family in spec.families: - prefix_key = _segment_key(family.prefix) - chain_prefix = prefix_key in chain_segment_keys - local_completions = tuple( - completion - for completion in family.completions - if _segment_key(completion) not in chain_segment_keys - ) - prefix_owner: int | None = None - if chain_prefix: - chain_prefix_segments.append(family.prefix) - cross_rank_token_count += _add_chain_search_load( - rank_loads, - family.prefix, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - ) - else: - owner_segments = ( - (family.prefix, *local_completions) - if co_locate_local_families - else (family.prefix,) - ) - prefix_owner = 
_best_segment_owner( - owner_segments, - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - local_prefix_segments_by_rank[prefix_owner].append(family.prefix) - cross_rank_token_count += _add_local_search_load( - rank_loads, - prefix_owner, - family.prefix, - segment_attention_counts=segment_attention_counts, - ) - for completion in family.completions: - completion_key = _segment_key(completion) - if completion_key in chain_segment_keys: - chain_completion_segments.append(completion) - cross_rank_token_count += _add_chain_search_load( - rank_loads, - completion, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - ) - if not chain_prefix: - parent_state_exchange_families.add(family.family_index) - continue - if co_locate_local_families and not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "co-located local completion planning lost the prefix owner" - ) - owner = prefix_owner - else: - owner = _best_segment_owner( - (completion,), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - if not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "local completion planning lost the prefix owner" - ) - if owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - local_completion_segments_by_rank[owner].append(completion) - cross_rank_token_count += _add_local_search_load( - rank_loads, - owner, - completion, - segment_attention_counts=segment_attention_counts, - ) - ( - local_work_by_rank, - local_bucket_count, - local_segment_count, - ) = _estimate_local_rank_kernel_work( - tuple(tuple(segments) for segments in local_prefix_segments_by_rank), - tuple(tuple(segments) for segments in local_completion_segments_by_rank), - planner_config=planner_config, - ) - chain_work_by_rank, chain_bucket_count = _estimate_chain_rank_kernel_work( - cp_size=cp_size, 
- chain_prefix_segments=tuple(chain_prefix_segments), - chain_completion_segments=tuple(chain_completion_segments), - chain_rank_counts_by_key=chain_rank_counts_by_key, - planner_config=planner_config, - ) - return _score_cp_segment_stats( - rank_local_work=local_work_by_rank, - rank_chain_work=chain_work_by_rank, - rank_real_tokens=tuple(rank_loads), - cross_rank_token_count=cross_rank_token_count, - parent_state_exchange_family_count=len(parent_state_exchange_families), - local_bucket_count=local_bucket_count, - local_segment_count=local_segment_count, - chain_bucket_count=chain_bucket_count, - planner_config=planner_config, - ) - - -def _estimate_local_rank_kernel_work( - local_prefix_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...], - local_completion_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...], - *, - planner_config: GdnPlannerConfig, -) -> tuple[tuple[int, ...], int, int]: - rank_work: list[int] = [] - rank_bucket_counts: list[int] = [] - rank_segment_counts: list[int] = [] - for prefix_segments, completion_segments in zip( - local_prefix_segments_by_rank, - local_completion_segments_by_rank, - strict=True, - ): - prefix_family_indices = {segment.family_index for segment in prefix_segments} - chunk_local_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index in prefix_family_indices - ) - plain_local_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index not in prefix_family_indices - ) - chunk_work, chunk_bucket_count = _estimate_chunk_aligned_local_work( - prefix_segments, - chunk_local_completion_segments, - planner_config=planner_config, - ) - completion_work, completion_bucket_count = _padded_work_from_lengths( - tuple(segment.length for segment in plain_local_completion_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - rank_work.append(chunk_work + 
completion_work) - rank_bucket_counts.append(chunk_bucket_count + completion_bucket_count) - rank_segment_counts.append(len(prefix_segments) + len(completion_segments)) - return ( - tuple(rank_work), - max(rank_bucket_counts, default=0), - max(rank_segment_counts, default=0), - ) - - -def _estimate_chunk_aligned_local_work( - prefix_segments: tuple[GdnSegmentSpec, ...], - completion_segments: tuple[GdnSegmentSpec, ...], - *, - planner_config: GdnPlannerConfig, -) -> tuple[int, int]: - completions_by_family: dict[int, list[GdnSegmentSpec]] = {} - for completion in completion_segments: - completions_by_family.setdefault(completion.family_index, []).append(completion) - boundary_lengths: list[int] = [] - tail_lengths: list[int] = [] - completion_column_lengths: list[int] = [] - for prefix in prefix_segments: - boundary_end = _prefix_chunk_boundary_end(prefix) - boundary_length = boundary_end - prefix.start - if boundary_length > 0: - boundary_lengths.append(boundary_length) - tail_length = prefix.end - boundary_end - family_completions = tuple(completions_by_family.get(prefix.family_index, ())) - if tail_length > 0 and not family_completions: - tail_lengths.append(tail_length) - for completion in family_completions: - completion_column_lengths.append(tail_length + completion.length) - boundary_work, boundary_bucket_count = _padded_work_from_lengths( - tuple(boundary_lengths), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - tail_work, tail_bucket_count = _padded_work_from_lengths( - tuple(tail_lengths), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_work, completion_bucket_count = _padded_work_from_lengths( - tuple(completion_column_lengths), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - boundary_work + tail_work + 
completion_work, - boundary_bucket_count + tail_bucket_count + completion_bucket_count, - ) - - -def _estimate_chain_rank_kernel_work( - *, - cp_size: int, - chain_prefix_segments: tuple[GdnSegmentSpec, ...], - chain_completion_segments: tuple[GdnSegmentSpec, ...], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - planner_config: GdnPlannerConfig, -) -> tuple[tuple[int, ...], int]: - rank_work = [0] * cp_size - bucket_count = 0 - for segments in (chain_prefix_segments, chain_completion_segments): - buckets = _batch_segments_by_padded_work( - segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - bucket_count += len(buckets) - for bucket in buckets: - for rank in range(cp_size): - lengths = tuple( - chain_rank_counts_by_key[_segment_key(segment)][rank] - for segment in bucket - ) - rank_work[rank] += max(lengths, default=0) * len(lengths) - return tuple(rank_work), bucket_count - - -def _padded_work_from_lengths( - lengths: tuple[int, ...], - *, - max_padding_ratio: float, - max_segments_per_batch: int, -) -> tuple[int, int]: - if not lengths: - return 0, 0 - ordered = sorted(length for length in lengths if length > 0) - if not ordered: - return 0, 0 - bucket_count = 0 - padded_work = 0 - current_count = 0 - current_tokens = 0 - current_max = 0 - for length in ordered: - next_count = current_count + 1 - next_tokens = current_tokens + length - next_max = max(current_max, length) - next_padded = next_max * next_count - can_extend = current_count == 0 or ( - next_count <= max_segments_per_batch - and next_padded <= max_padding_ratio * next_tokens - ) - if not can_extend: - bucket_count += 1 - padded_work += current_max * current_count - current_count = 0 - current_tokens = 0 - current_max = 0 - current_count += 1 - current_tokens += length - current_max = max(current_max, length) - if current_count: - bucket_count += 1 - padded_work += current_max * current_count - 
return padded_work, bucket_count - - -def _add_chain_search_load( - rank_loads: list[int], - segment: GdnSegmentSpec, - *, - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - chain_cross_rank_tokens_by_key: dict[GdnSegmentDecisionKey, int], -) -> int: - key = _segment_key(segment) - for rank, token_count in enumerate(chain_rank_counts_by_key[key]): - rank_loads[rank] += token_count - return chain_cross_rank_tokens_by_key[key] - - -def _add_local_search_load( - rank_loads: list[int], - rank: int, - segment: GdnSegmentSpec, - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], -) -> int: - rank_loads[rank] += segment.length - return segment.length - segment_attention_counts[_segment_key(segment)][rank] - - -def _chain_segment_rank_counts_and_cross_rank_tokens( - segment: GdnSegmentSpec, - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_layout_index: _AttentionLayoutIndex, -) -> tuple[tuple[int, ...], int]: - token_start = _segment_token_start(segment, spec.sequence_length) - attention_shards = _attention_contiguous_chain_shards( - token_start, - segment.length, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - ) - if attention_shards is not None: - return tuple(len(shard) for shard in attention_shards), 0 - shard_lengths = _fla_aligned_chain_shard_lengths(segment.length, cp_size=cp_size) - cross_rank_tokens = 0 - start = 0 - for rank, shard_length in enumerate(shard_lengths): - end = start + shard_length - shard_start = token_start + start - cross_rank_tokens += shard_length - _attention_overlap_count( - attention_layout_index, - rank, - shard_start, - shard_start + shard_length, - ) - start = end - return shard_lengths, cross_rank_tokens - - -def _materialize_cp_segment_schedule( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_layout_index: _AttentionLayoutIndex, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_segment_keys: 
frozenset[GdnSegmentDecisionKey], - co_locate_local_families: bool, - planner_config: GdnPlannerConfig, -) -> GdnCpSegmentSchedule: - gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] - rank_loads = [0] * cp_size - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - local_completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - chain_prefix_segments: list[GdnSegmentSpec] = [] - chain_completion_segments: list[GdnSegmentSpec] = [] - parent_state_exchange_families: set[int] = set() - parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} - cross_rank_token_count = 0 - - for family in spec.families: - prefix_key = _segment_key(family.prefix) - chain_prefix = prefix_key in chain_segment_keys - local_completions = tuple( - completion - for completion in family.completions - if _segment_key(completion) not in chain_segment_keys - ) - prefix_owner: int | None = None - if chain_prefix: - chain_prefix_segments.append(family.prefix) - cross_rank_token_count += _append_chain_segment( - gdn_ranges_by_rank, - rank_loads, - family.prefix, - spec, - attention_layout_index=attention_layout_index, - ) - else: - owner_segments = ( - (family.prefix, *local_completions) - if co_locate_local_families - else (family.prefix,) - ) - prefix_owner = _best_segment_owner( - owner_segments, - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - local_prefix_segments_by_rank[prefix_owner].append(family.prefix) - cross_rank_token_count += _append_local_segment( - gdn_ranges_by_rank, - rank_loads, - prefix_owner, - family.prefix, - spec, - segment_attention_counts=segment_attention_counts, - ) - for completion in family.completions: - if _segment_key(completion) in chain_segment_keys: - chain_completion_segments.append(completion) - cross_rank_token_count += _append_chain_segment( - gdn_ranges_by_rank, - rank_loads, - 
completion, - spec, - attention_layout_index=attention_layout_index, - ) - if not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "local-prefix/chained-completion planning lost the prefix owner" - ) - parent_state_exchange_families.add(family.family_index) - for dest_rank in range(cp_size): - if dest_rank == prefix_owner: - continue - parent_state_transfer_families.setdefault( - (prefix_owner, dest_rank), set() - ).add(family.family_index) - continue - if co_locate_local_families and not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "co-located local completion planning lost the prefix owner" - ) - owner = prefix_owner - else: - owner = _best_segment_owner( - (completion,), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - if not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "local completion planning lost the prefix owner" - ) - if owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - parent_state_transfer_families.setdefault( - (prefix_owner, owner), set() - ).add(family.family_index) - local_completion_segments_by_rank[owner].append(completion) - cross_rank_token_count += _append_local_segment( - gdn_ranges_by_rank, - rank_loads, - owner, - completion, - spec, - segment_attention_counts=segment_attention_counts, - ) - - return GdnCpSegmentSchedule.model_construct( - gdn_token_counts_by_rank=tuple(rank_loads), - gdn_token_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), - cross_rank_token_count=cross_rank_token_count, - chain_prefix_buckets=_batch_segments_by_padded_work( - tuple(chain_prefix_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ), - chain_completion_buckets=_batch_segments_by_padded_work( - tuple(chain_completion_segments), - max_padding_ratio=planner_config.max_padding_ratio, - 
max_segments_per_batch=planner_config.max_segments_per_batch, - ), - local_prefix_segments_by_rank=tuple( - tuple(segments) for segments in local_prefix_segments_by_rank - ), - local_completion_segments_by_rank=tuple( - tuple(segments) for segments in local_completion_segments_by_rank - ), - parent_state_exchange_family_indices=tuple( - sorted(parent_state_exchange_families) - ), - parent_state_transfers=_build_parent_state_transfer_plans( - parent_state_transfer_families - ), - ) - - -def _build_local_family_rank_execution_plan( - spec: GdnPackedExecutionSpec, - *, - device: torch.device | str, - cp_rank: int, - cp_size: int, - planner_config: GdnPlannerConfig, -) -> GdnRankExecutionPlan | None: - if cp_size <= 1 or not spec.families: - return None - target_rank_load = spec.real_token_count / cp_size - loads = [0] * cp_size - prefix_owner_by_family: list[int] = [] - completion_owners_by_family: list[tuple[int, ...]] = [] - for family in spec.families: - if _has_chainable_segment( - family, cp_size=cp_size, planner_config=planner_config - ): - return None - prefix_locality_limit = max( - planner_config.max_zero_exchange_load_imbalance * target_rank_load, - min(64.0, float(spec.real_token_count)), - ) - if family.prefix.length > prefix_locality_limit: - return None - owner = _least_loaded_rank(loads) - prefix_owner_by_family.append(owner) - completion_owners_by_family.append(tuple(owner for _ in family.completions)) - loads[owner] += family.token_count - - if max(loads, default=0) > ( - planner_config.local_completion_rebalance_min_imbalance * target_rank_load - ): - completion_owners_by_family = list( - _rebalance_local_completion_segments( - spec, - prefix_owner_by_family=tuple(prefix_owner_by_family), - completion_owners_by_family=tuple(completion_owners_by_family), - initial_loads=tuple(loads), - planner_config=planner_config, - ) - ) - rank_assignments = _materialize_local_family_rank_assignments( - spec, - cp_size=cp_size, - 
prefix_owner_by_family=tuple(prefix_owner_by_family), - completion_owners_by_family=tuple(completion_owners_by_family), - ) - local_token_count, local_token_ranges, prefix_segments, completion_segments = ( - rank_assignments[cp_rank] - ) - parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - completion_owners = completion_owners_by_family[family.family_index] - for completion_owner in sorted(set(completion_owners)): - if completion_owner == prefix_owner: - continue - parent_state_transfer_families.setdefault( - (prefix_owner, completion_owner), set() - ).add(family.family_index) - - from art.megatron.gdn.layout import GdnCpExchangePlan, GdnCpPeerTransfer - - token_counts_by_rank = tuple(assignment[0] for assignment in rank_assignments) - identity_exchange = GdnCpExchangePlan.model_construct( - cp_size=cp_size, - source_token_counts_by_rank=token_counts_by_rank, - dest_token_counts_by_rank=token_counts_by_rank, - transfers=tuple( - GdnCpPeerTransfer.model_construct( - source_rank=rank, - dest_rank=rank, - token_count=token_count, - source_positions_tensor=None, - dest_positions_tensor=None, - ) - for rank, token_count in enumerate(token_counts_by_rank) - if token_count - ), - ) - parent_state_exchange_family_indices = tuple( - sorted( - family_index - for family_indices in parent_state_transfer_families.values() - for family_index in family_indices - ) - ) - schedule = GdnCpSegmentSchedule.model_construct( - gdn_token_counts_by_rank=token_counts_by_rank, - gdn_token_ranges_by_rank=tuple( - assignment[1] for assignment in rank_assignments - ), - cross_rank_token_count=0, - chain_prefix_buckets=(), - chain_completion_buckets=(), - local_prefix_segments_by_rank=tuple( - assignment[2] for assignment in rank_assignments - ), - local_completion_segments_by_rank=tuple( - assignment[3] for assignment in rank_assignments - ), - 
parent_state_exchange_family_indices=parent_state_exchange_family_indices, - parent_state_transfers=_build_parent_state_transfer_plans( - parent_state_transfer_families - ), - ) - if parent_state_exchange_family_indices: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _build_remote_prefix_tail_plans( - spec, - schedule, - cp_rank=cp_rank, - device=device, - planner_config=planner_config, - ) - else: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _empty_remote_prefix_tail_plans() - local_prefix_family_indices = {segment.family_index for segment in prefix_segments} - chunk_local_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index in local_prefix_family_indices - ) - suffix_only_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index not in local_prefix_family_indices - and segment.family_index not in remote_prefix_tail_families - ) - ready_completion_segments, remote_completion_segments = ( - _split_ready_and_remote_completion_segments( - suffix_only_completion_segments, - local_prefix_segments=(), - chain_prefix_buckets=(), - ) - ) - ready_completion_buckets = _batch_segments_by_padded_work( - ready_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - remote_completion_buckets = _batch_segments_by_padded_work( - remote_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - ready_completion_bucket_plans = _build_position_bucket_plans( - 
ready_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - remote_completion_bucket_plans = _build_position_bucket_plans( - remote_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - local_completion_bucket_plans = ( - ready_completion_bucket_plans + remote_completion_bucket_plans - ) - ( - prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_position_bucket_plans( - prefix_segments, - chunk_local_completion_segments, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - planner_config=planner_config, - ) - return GdnRankExecutionPlan.model_construct( - cp_rank=cp_rank, - cp_size=cp_size, - batch_size=1, - sequence_length=local_token_count, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones( - 1, local_token_count, device=device, dtype=torch.bool - ), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=(), - local_completion_buckets=local_completion_bucket_plans, - ready_local_completion_buckets=ready_completion_bucket_plans, - remote_local_completion_buckets=remote_completion_bucket_plans, - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=( - tuple(segment.family_index for segment in prefix_segments) - == tuple(range(spec.family_count)) - ), - attention_to_gdn=identity_exchange, - gdn_to_attention=identity_exchange, - attention_token_ranges=local_token_ranges, - gdn_token_ranges=local_token_ranges, - attention_token_count=local_token_count, - gdn_token_count=local_token_count, - parent_state_exchange_family_indices=tuple( - family_index - for family_index in parent_state_exchange_family_indices - if family_index not in remote_prefix_tail_families - ), - parent_state_transfers=_filter_parent_state_transfers( - 
_build_parent_state_transfer_plans(parent_state_transfer_families), - excluded_families=remote_prefix_tail_families, - device=device, - ), - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - remote_prefix_tail_buckets=remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets=remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange=remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange=remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers=remote_prefix_tail_state_transfers, - ) - - -def _rebalance_local_completion_segments( - spec: GdnPackedExecutionSpec, - *, - prefix_owner_by_family: tuple[int, ...], - completion_owners_by_family: tuple[tuple[int, ...], ...], - initial_loads: tuple[int, ...], - planner_config: GdnPlannerConfig, -) -> tuple[tuple[int, ...], ...]: - owners = [list(family_owners) for family_owners in completion_owners_by_family] - loads = list(initial_loads) - remote_owners_by_family = [ - { - owner - for owner in family_owners - if owner != prefix_owner_by_family[family_index] - } - for family_index, family_owners in enumerate(owners) - ] - transfer_count = sum( - len(remote_owners) for remote_owners in remote_owners_by_family - ) - - def score(candidate_loads: list[int], candidate_transfer_count: int) -> float: - max_load = max(candidate_loads, default=0) - idle_tokens = sum(max_load - load for load in candidate_loads) - return ( - max_load - + planner_config.rank_idle_token_cost * idle_tokens - + planner_config.parent_state_exchange_penalty_tokens - * candidate_transfer_count - ) - - best_score = score(loads, transfer_count) - while True: - best_move: ( - tuple[int, int, int, tuple[int, ...], list[int], int, float] | None - ) = None - for family in spec.families: - family_owners = owners[family.family_index] - prefix_owner = 
prefix_owner_by_family[family.family_index] - original_remote_owners = remote_owners_by_family[family.family_index] - for source in sorted(set(family_owners)): - source_children = [ - child_index - for child_index, owner in enumerate(family_owners) - if owner == source - ] - ordered_children = sorted( - source_children, - key=lambda child_index: family.completions[child_index].length, - reverse=True, - ) - for dest in range(len(loads)): - if dest == source: - continue - moved_tokens = 0 - moved_children = [] - for child_index in ordered_children: - moved_tokens += family.completions[child_index].length - moved_children.append(child_index) - candidate_loads = list(loads) - candidate_loads[source] -= moved_tokens - candidate_loads[dest] += moved_tokens - candidate_remote_owners = set(original_remote_owners) - if source != prefix_owner and len(moved_children) == len( - source_children - ): - candidate_remote_owners.discard(source) - if dest != prefix_owner: - candidate_remote_owners.add(dest) - candidate_transfer_count = ( - transfer_count - - len(original_remote_owners) - + len(candidate_remote_owners) - ) - candidate_score = score( - candidate_loads, candidate_transfer_count - ) - if candidate_score >= best_score: - continue - if best_move is None or candidate_score < best_move[-1]: - best_move = ( - family.family_index, - source, - dest, - tuple(moved_children), - candidate_loads, - candidate_transfer_count, - candidate_score, - ) - if best_move is None: - return tuple(tuple(item) for item in owners) - ( - family_index, - _source, - dest, - moved_children, - loads, - transfer_count, - best_score, - ) = best_move - for child_index in moved_children: - owners[family_index][child_index] = dest - prefix_owner = prefix_owner_by_family[family_index] - remote_owners_by_family[family_index] = { - owner for owner in set(owners[family_index]) if owner != prefix_owner - } - - -def _materialize_local_family_rank_assignments( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, 
- prefix_owner_by_family: tuple[int, ...], - completion_owners_by_family: tuple[tuple[int, ...], ...], -) -> tuple[ - tuple[ - int, - tuple[tuple[int, int, int], ...], - tuple[GdnSegmentSpec, ...], - tuple[GdnSegmentSpec, ...], - ], - ..., -]: - token_ranges_by_rank: list[list[tuple[int, int, int]]] = [ - [] for _ in range(cp_size) - ] - token_counts_by_rank = [0] * cp_size - prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [[] for _ in range(cp_size)] - completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - sequence_length = spec.sequence_length - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - prefix_segments_by_rank[prefix_owner].append(family.prefix) - prefix_token_start = ( - family.prefix.row_index * sequence_length + family.prefix.start - ) - prefix_position_start = token_counts_by_rank[prefix_owner] - token_ranges_by_rank[prefix_owner].append( - ( - prefix_token_start, - prefix_token_start + family.prefix.length, - prefix_position_start, - ) - ) - token_counts_by_rank[prefix_owner] = ( - prefix_position_start + family.prefix.length - ) - for completion, completion_owner in zip( - family.completions, - completion_owners_by_family[family.family_index], - strict=True, - ): - completion_segments_by_rank[completion_owner].append(completion) - completion_token_start = ( - completion.row_index * sequence_length + completion.start - ) - completion_position_start = token_counts_by_rank[completion_owner] - token_ranges_by_rank[completion_owner].append( - ( - completion_token_start, - completion_token_start + completion.length, - completion_position_start, - ) +) -> TokenLayoutIndex: + if attention_token_layout_index is not None: + if _layout_cp_size(attention_token_layout_index) != cp_size: + raise ValueError( + "attention token layout index cp_size must match GDN cp_size, got " + f"{_layout_cp_size(attention_token_layout_index)} and {cp_size}" ) - 
token_counts_by_rank[completion_owner] = ( - completion_position_start + completion.length + if _layout_token_count(attention_token_layout_index) != spec.real_token_count: + raise ValueError( + "attention token layout index token count must match GDN real token " + f"count, got {_layout_token_count(attention_token_layout_index)} and " + f"{spec.real_token_count}" ) - return tuple( - ( - token_counts_by_rank[rank], - tuple(token_ranges_by_rank[rank]), - tuple(prefix_segments_by_rank[rank]), - tuple(completion_segments_by_rank[rank]), - ) - for rank in range(cp_size) - ) - - -def _empty_local_family_rank_execution_plan( - spec: GdnPackedExecutionSpec, - *, - device: torch.device | str, - cp_rank: int, - cp_size: int, -) -> GdnRankExecutionPlan: - from art.megatron.gdn.layout import GdnCpExchangePlan - - identity_exchange = GdnCpExchangePlan.model_construct( - cp_size=cp_size, - source_token_counts_by_rank=tuple(0 for _ in range(cp_size)), - dest_token_counts_by_rank=tuple(0 for _ in range(cp_size)), - transfers=(), - ) - return GdnRankExecutionPlan.model_construct( - cp_rank=cp_rank, - cp_size=cp_size, - batch_size=1, - sequence_length=0, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones(1, 0, device=device, dtype=torch.bool), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=(), - local_completion_buckets=(), - ready_local_completion_buckets=(), - remote_local_completion_buckets=(), - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=False, - attention_to_gdn=identity_exchange, - gdn_to_attention=identity_exchange, - attention_token_ranges=(), - gdn_token_ranges=(), - attention_token_count=0, - gdn_token_count=0, - parent_state_exchange_family_indices=(), - parent_state_transfers=(), + return attention_token_layout_index + return _token_layout_from_rank_ranges( + _default_attention_layout_ranges( + spec, + 
cp_size=cp_size, + planner_config=planner_config, + ) ) @@ -3179,143 +689,56 @@ def _can_chain_segment( if segment.kind == "prefix" else planner_config.cp_chain_min_total_tokens ) - if segment.length < min_tokens: - return False - if segment.length < cp_size: - return False - if segment.length // FLA_CHUNK_SIZE < cp_size: - return False - per_rank = segment.length / cp_size - if per_rank < planner_config.cp_chain_min_tokens_per_rank: - return False - return True - - -def _build_parent_state_transfer_plans( - families_by_peer: dict[tuple[int, int], set[int]], -) -> tuple[GdnParentStateTransferPlan, ...]: - return tuple( - GdnParentStateTransferPlan( - source_rank=source_rank, - dest_rank=dest_rank, - family_indices=tuple(sorted(family_indices)), - ) - for (source_rank, dest_rank), family_indices in sorted(families_by_peer.items()) - if source_rank != dest_rank and family_indices - ) - - -def _split_ready_and_remote_completion_segments( - completion_segments: tuple[GdnSegmentSpec, ...], - *, - local_prefix_segments: tuple[GdnSegmentSpec, ...], - chain_prefix_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], -) -> tuple[tuple[GdnSegmentSpec, ...], tuple[GdnSegmentSpec, ...]]: - ready_family_indices = { - segment.family_index for segment in local_prefix_segments - } | {segment.family_index for bucket in chain_prefix_buckets for segment in bucket} - ready = [] - remote = [] - for segment in completion_segments: - if segment.family_index in ready_family_indices: - ready.append(segment) - else: - remote.append(segment) - return tuple(ready), tuple(remote) - - -def _transfer_plans_to_device( - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - device: torch.device | str, -) -> tuple[GdnParentStateTransferPlan, ...]: - return tuple( - transfer.model_copy( - update={ - "family_indices_tensor": _move_planner_tensor( - torch.tensor(transfer.family_indices, dtype=torch.long), - device, - ) - } - ) - for transfer in transfers + return _can_chain_segment_with_min_tokens( + 
segment, + cp_size=cp_size, + min_tokens=min_tokens, + planner_config=planner_config, ) -def _has_chainable_segment( - family: GdnPackedFamilySpec, +def _can_chain_tree_segment( + segment: GdnSegmentSpec, *, cp_size: int, planner_config: GdnPlannerConfig, ) -> bool: - return _can_chain_prefix_segment( - family.prefix, cp_size=cp_size, planner_config=planner_config - ) or any( - _can_chain_segment(completion, cp_size=cp_size, planner_config=planner_config) - for completion in family.completions + min_tokens = ( + min( + planner_config.cp_tree_chain_min_prefix_only_tokens, + planner_config.cp_chain_min_prefix_only_tokens, + ) + if segment.kind == "prefix" + else min( + planner_config.cp_tree_chain_min_total_tokens, + planner_config.cp_chain_min_total_tokens, + ) + ) + return _can_chain_segment_with_min_tokens( + segment, + cp_size=cp_size, + min_tokens=min_tokens, + planner_config=planner_config, ) -def _can_chain_prefix_segment( +def _can_chain_segment_with_min_tokens( segment: GdnSegmentSpec, *, cp_size: int, + min_tokens: int, planner_config: GdnPlannerConfig, ) -> bool: - return _can_chain_segment(segment, cp_size=cp_size, planner_config=planner_config) - - -def _score_cp_segment_stats( - *, - rank_local_work: tuple[int, ...], - rank_chain_work: tuple[int, ...], - rank_real_tokens: tuple[int, ...], - cross_rank_token_count: int, - parent_state_exchange_family_count: int, - local_bucket_count: int, - local_segment_count: int, - chain_bucket_count: int, - planner_config: GdnPlannerConfig, -) -> float: - empty_rank_count = sum(1 for token_count in rank_real_tokens if token_count == 0) - return ( - _rank_kernel_ms( - rank_local_work, - rank_chain_work, - local_token_ms=planner_config.planner_local_token_ms, - chain_token_ms=planner_config.planner_chain_token_ms, - ) - + planner_config.planner_local_bucket_ms * local_bucket_count - + planner_config.planner_chain_bucket_ms * chain_bucket_count - + planner_config.planner_local_segment_ms * local_segment_count - + 
planner_config.planner_layout_cross_rank_token_ms * cross_rank_token_count - + ( - planner_config.planner_parent_state_exchange_base_ms - + planner_config.planner_parent_state_exchange_ms - * parent_state_exchange_family_count - if parent_state_exchange_family_count - else 0.0 - ) - + planner_config.planner_empty_rank_ms * empty_rank_count - ) - - -def _rank_kernel_ms( - rank_local_work: tuple[int, ...], - rank_chain_work: tuple[int, ...], - *, - local_token_ms: float, - chain_token_ms: float, -) -> float: - return max( - ( - local_work * local_token_ms + chain_work * chain_token_ms - for local_work, chain_work in zip( - rank_local_work, rank_chain_work, strict=True - ) - ), - default=0.0, - ) + if segment.length < min_tokens: + return False + if segment.length < cp_size: + return False + if segment.length // FLA_CHUNK_SIZE < cp_size: + return False + per_rank = segment.length / cp_size + if per_rank < planner_config.cp_chain_min_tokens_per_rank: + return False + return True def _best_segment_owner( @@ -3336,11 +759,16 @@ def _best_segment_owner( for rank in range(rank_count): counts_by_rank[rank] += segment_counts[rank] on_rank_tokens = tuple(counts_by_rank) - best: tuple[float, int, int, int, int] | None = None + best: tuple[float, float, int, int, int, int] | None = None for rank, tokens in enumerate(on_rank_tokens): projected_loads = list(rank_loads) projected_loads[rank] += segment_length max_load = max(projected_loads, default=0) + target_load = sum(projected_loads) / max(1, len(projected_loads)) + overload = max( + 0.0, + max_load - planner_config.max_zero_exchange_load_imbalance * target_load, + ) idle_tokens = sum(max_load - load for load in projected_loads) cross_rank_tokens = segment_length - int(tokens) empty_rank_count = sum(1 for load in projected_loads if load == 0) @@ -3353,6 +781,7 @@ def _best_segment_owner( + empty_rank_count * planner_config.planner_empty_rank_ms ) candidate = ( + overload, score, max_load, cross_rank_tokens, @@ -3366,6 +795,118 
@@ def _best_segment_owner( return best[-1] +def _build_tree_state_exchanges_by_depth( + spec: GdnPackedExecutionSpec, + *, + owner_by_node: tuple[int, ...], + chained_nodes: tuple[bool, ...], + cp_rank: int, + cp_size: int, + depth_count: int, + device: torch.device | str, +) -> tuple[GdnStateExchangePlan | None, ...]: + if cp_size <= 1: + return tuple(None for _ in range(depth_count)) + + from art.megatron.gdn.layout import ( + GdnCpExchangePlan, + _make_peer_transfer, + _reverse_exchange_plan, + ) + + families_by_depth_pair: list[dict[tuple[int, int], set[int]]] = [ + {} for _ in range(depth_count) + ] + for child_index, parent_index in enumerate(spec.tree_parent_indices): + if parent_index < 0 or chained_nodes[parent_index]: + continue + source_rank = owner_by_node[parent_index] + dest_rank = owner_by_node[child_index] + if source_rank < 0 or dest_rank < 0: + raise ValueError("tree state exchange requires every node to have an owner") + if source_rank == dest_rank: + continue + depth = spec.tree_depths[child_index] + families_by_depth_pair[depth].setdefault((source_rank, dest_rank), set()).add( + parent_index + ) + + state_exchanges: list[GdnStateExchangePlan | None] = [] + for pair_families in families_by_depth_pair: + if not pair_families: + state_exchanges.append(None) + continue + source_families_by_rank = [set[int]() for _ in range(cp_size)] + dest_families_by_rank = [set[int]() for _ in range(cp_size)] + for (source_rank, dest_rank), parent_indices in pair_families.items(): + source_families_by_rank[source_rank].update(parent_indices) + dest_families_by_rank[dest_rank].update(parent_indices) + source_families = tuple( + tuple(sorted(families)) for families in source_families_by_rank + ) + dest_families = tuple( + tuple(sorted(families)) for families in dest_families_by_rank + ) + source_positions = ( + {family: index for index, family in enumerate(families)} + for families in source_families + ) + dest_positions = ( + {family: index for index, family in 
enumerate(families)} + for families in dest_families + ) + source_position_by_rank = tuple(source_positions) + dest_position_by_rank = tuple(dest_positions) + transfers = [] + transfer_count = 0 + for (source_rank, dest_rank), parent_indices in sorted(pair_families.items()): + ordered = tuple(sorted(parent_indices)) + transfer_count += len(ordered) + transfers.append( + _make_peer_transfer( + source_rank=source_rank, + dest_rank=dest_rank, + source_positions=torch.tensor( + [ + source_position_by_rank[source_rank][family] + for family in ordered + ], + dtype=torch.long, + ), + dest_positions=torch.tensor( + [ + dest_position_by_rank[dest_rank][family] + for family in ordered + ], + dtype=torch.long, + ), + source_count=len(source_families[source_rank]), + dest_count=len(dest_families[dest_rank]), + device=device, + ) + ) + exchange = GdnCpExchangePlan.model_construct( + cp_size=cp_size, + source_token_counts_by_rank=tuple( + len(families) for families in source_families + ), + dest_token_counts_by_rank=tuple( + len(families) for families in dest_families + ), + transfers=tuple(transfers), + cross_rank_token_count_override=transfer_count, + ) + state_exchanges.append( + GdnStateExchangePlan.model_construct( + source_family_indices=source_families[cp_rank], + dest_family_indices=dest_families[cp_rank], + exchange=exchange, + reverse_exchange=_reverse_exchange_plan(exchange), + ) + ) + return tuple(state_exchanges) + + def _build_attention_layout_index_from_token_layout( layout: TokenLayoutIndex, *, @@ -3472,61 +1013,22 @@ def should_split_segment(segment: GdnSegmentSpec) -> bool: target_rank_load ): return False - if segment.kind == "prefix": - return _can_chain_prefix_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - return _can_chain_segment( + return _can_chain_tree_segment( segment, cp_size=cp_size, planner_config=planner_config ) - for family in spec.families: - has_split_segment = any( - should_split_segment(segment) - for segment in 
(family.prefix, *family.completions) - ) - if not has_split_segment: - if _should_co_locate_non_chain_family( - family, - total_real_tokens=spec.real_token_count, - cp_size=cp_size, - planner_config=planner_config, - ): - owner = _least_loaded_rank(loads) - for segment in (family.prefix, *family.completions): - token_start = _segment_token_start(segment, spec.sequence_length) - append_segment(owner, token_start, segment.length) - continue - for segment in (family.prefix, *family.completions): - token_start = _segment_token_start(segment, spec.sequence_length) - owner = _least_loaded_rank(loads) - append_segment(owner, token_start, segment.length) + for segment in spec.tree_segments: + token_start = _segment_token_start(segment, spec.sequence_length) + if should_split_segment(segment): + _append_split_default_attention_segment( + ranks, loads, token_start, segment.length + ) continue - for segment in (family.prefix, *family.completions): - token_start = _segment_token_start(segment, spec.sequence_length) - if should_split_segment(segment): - _append_split_default_attention_segment( - ranks, loads, token_start, segment.length - ) - continue - owner = _least_loaded_rank(loads) - append_segment(owner, token_start, segment.length) + owner = _least_loaded_rank(loads) + append_segment(owner, token_start, segment.length) return tuple(tuple(ranges) for ranges in ranks) -def _should_co_locate_non_chain_family( - family: GdnPackedFamilySpec, - *, - total_real_tokens: int, - cp_size: int, - planner_config: GdnPlannerConfig, -) -> bool: - target_rank_load = total_real_tokens / cp_size - return family.token_count <= ( - planner_config.max_zero_exchange_load_imbalance * target_rank_load - ) - - def _append_split_default_attention_segment( ranks: list[list[tuple[int, int, int]]], loads: list[int], @@ -3591,26 +1093,6 @@ def _append_chain_segment( return cross_rank_tokens -def _chain_rank_token_indices( - segment: GdnSegmentSpec, - spec: GdnPackedExecutionSpec, - *, - cp_rank: int, 
- cp_size: int, -) -> range: - token_start = _segment_token_start(segment, spec.sequence_length) - lengths = _fla_aligned_chain_shard_lengths(segment.length, cp_size=cp_size) - start = sum(lengths[:cp_rank]) - end = start + lengths[cp_rank] - if start >= end: - raise ValueError( - "CP chain planning requires non-empty shards; " - f"segment={segment.kind}:{segment.family_index} " - f"length={segment.length} cp_size={cp_size}" - ) - return range(token_start + start, token_start + end) - - def _fla_aligned_chain_shard_lengths(length: int, *, cp_size: int) -> tuple[int, ...]: full_chunks = int(length) // FLA_CHUNK_SIZE if full_chunks < int(cp_size): @@ -3695,14 +1177,99 @@ def _least_loaded_rank(rank_loads: list[int]) -> int: return min(range(len(rank_loads)), key=lambda rank: (rank_loads[rank], rank)) -def _owner_rank( - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]], - prefix: GdnSegmentSpec, -) -> int: - for rank, segments in enumerate(local_prefix_segments_by_rank): - if prefix in segments: - return rank - raise RuntimeError("local prefix owner was not recorded") +def _build_tree_segment_bucket_plans( + segments: tuple[GdnSegmentSpec, ...], + tree_parent_indices: tuple[int, ...], + tree_has_children: tuple[bool, ...], + *, + device: torch.device | str, + planner_config: GdnPlannerConfig, +) -> tuple[GdnSegmentBucketPlan, ...]: + segment_buckets = _batch_tree_segments_by_padded_work( + segments, + tree_has_children, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + plans = _build_segment_bucket_plans(segment_buckets, device=device) + return tuple( + _bucket_with_tree_parent_indices( + plan, + bucket, + tree_parent_indices, + tree_has_children, + device=device, + ) + for plan, bucket in zip(plans, segment_buckets, strict=True) + ) + + +def _build_tree_position_bucket_plans( + segments: tuple[GdnSegmentSpec, ...], + tree_parent_indices: tuple[int, ...], + tree_has_children: 
tuple[bool, ...], + local_token_ranges: tuple[tuple[int, int, int], ...], + *, + sequence_length: int, + device: torch.device | str, + planner_config: GdnPlannerConfig, + token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] | None = None, + split_by_final_state: bool = True, +) -> tuple[GdnSegmentBucketPlan, ...]: + segment_buckets = ( + _batch_tree_segments_by_padded_work( + segments, + tree_has_children, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + if split_by_final_state + else _batch_segments_by_padded_work( + segments, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + ) + plans = _build_position_bucket_plans( + segment_buckets, + local_token_ranges, + sequence_length=sequence_length, + device=device, + token_ranges_by_rank=token_ranges_by_rank, + ) + return tuple( + _bucket_with_tree_parent_indices( + plan, + bucket, + tree_parent_indices, + tree_has_children, + device=device, + ) + for plan, bucket in zip(plans, segment_buckets, strict=True) + ) + + +def _bucket_with_tree_parent_indices( + plan: GdnSegmentBucketPlan, + segments: tuple[GdnSegmentSpec, ...], + tree_parent_indices: tuple[int, ...], + tree_has_children: tuple[bool, ...], + *, + device: torch.device | str, +) -> GdnSegmentBucketPlan: + parent_indices = torch.tensor( + [tree_parent_indices[segment.family_index] for segment in segments], + dtype=torch.long, + ) + return plan.model_copy( + update={ + "parent_indices": _move_planner_tensor(parent_indices, device), + "parent_indices_cpu": parent_indices, + "needs_final_state": any( + tree_has_children[segment.family_index] for segment in segments + ), + } + ) def _build_position_bucket_plans( @@ -3791,6 +1358,7 @@ def _build_position_bucket_plan( row_indices=_move_planner_tensor(row_indices_cpu, device), position_indices=_move_planner_tensor(position_indices_cpu, device), 
family_indices=_move_planner_tensor(family_indices_cpu, device), + family_indices_cpu=family_indices_cpu, real_token_count_static=sum(lengths), ) @@ -3850,6 +1418,7 @@ def _build_exact_range_position_bucket_plan( row_indices=_move_planner_tensor(row_indices_cpu, device), position_indices=_move_planner_tensor(position_indices_cpu, device), family_indices=_move_planner_tensor(family_indices_cpu, device), + family_indices_cpu=family_indices_cpu, real_token_count_static=sum(lengths), ) @@ -3927,6 +1496,33 @@ def _batch_segments_by_padded_work( return tuple(tuple(batch) for batch in batches) +def _batch_tree_segments_by_padded_work( + segments: tuple[GdnSegmentSpec, ...], + tree_has_children: tuple[bool, ...], + *, + max_padding_ratio: float = 1.25, + max_segments_per_batch: int = 128, +) -> tuple[tuple[GdnSegmentSpec, ...], ...]: + stateful = tuple( + segment for segment in segments if tree_has_children[segment.family_index] + ) + stateless = tuple( + segment for segment in segments if not tree_has_children[segment.family_index] + ) + return ( + *_batch_segments_by_padded_work( + stateful, + max_padding_ratio=max_padding_ratio, + max_segments_per_batch=max_segments_per_batch, + ), + *_batch_segments_by_padded_work( + stateless, + max_padding_ratio=max_padding_ratio, + max_segments_per_batch=max_segments_per_batch, + ), + ) + + def _build_segment_bucket_plan( length: int, segments: tuple[GdnSegmentSpec, ...], *, device: torch.device | str ) -> GdnSegmentBucketPlan: @@ -3961,6 +1557,7 @@ def _build_segment_bucket_plan( ), position_indices=_move_planner_tensor(positions_cpu, device), family_indices=_move_planner_tensor(family_indices_cpu, device), + family_indices_cpu=family_indices_cpu, real_token_count_static=sum(segment.length for segment in segments), ) @@ -4012,27 +1609,6 @@ def _range_overlaps( return overlaps -def _local_token_ranges( - local_gdn_tokens: tuple[int, ...], -) -> tuple[tuple[int, int, int], ...]: - if not local_gdn_tokens: - return () - ranges = [] - 
token_start = local_gdn_tokens[0] - token_end = token_start + 1 - position_start = 0 - for position, token in enumerate(local_gdn_tokens[1:], start=1): - if token == token_end: - token_end += 1 - continue - ranges.append((token_start, token_end, position_start)) - token_start = token - token_end = token + 1 - position_start = position - ranges.append((token_start, token_end, position_start)) - return tuple(ranges) - - def _local_positions_for_segment( segment: GdnSegmentSpec, *, @@ -4079,285 +1655,3 @@ def _rank2_long_cpu(name: str, tensor: torch.Tensor) -> torch.Tensor: ): raise TypeError(f"{name} must contain integer ids, got dtype={tensor.dtype}") return tensor.detach().to(device="cpu", dtype=torch.long) - - -def _validate_padding_tensor( - row_index: int, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, -) -> int: - padding_positions = torch.nonzero(group_ids == -1, as_tuple=False) - valid_length = ( - int(padding_positions[0].item()) - if int(padding_positions.numel()) > 0 - else int(group_ids.numel()) - ) - if valid_length == 0: - if bool(torch.any(parent_ids != -1).item()): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") - return 0 - if bool(torch.any(group_ids[valid_length:] != -1).item()): - raise ValueError( - f"row {row_index}: valid tokens must be contiguous before padding" - ) - if bool(torch.any(parent_ids[:valid_length] == -1).item()): - raise ValueError( - f"row {row_index}: valid tokens must have non-padding parent_ids" - ) - if bool(torch.any(parent_ids[valid_length:] != -1).item()): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") - return valid_length - - -def _validate_padding( - row_index: int, - group_ids: list[int], - parent_ids: list[int], -) -> int: - valid_length = 0 - for group_id in group_ids: - if group_id == -1: - break - valid_length += 1 - if valid_length == 0: - if any(parent_id != -1 for parent_id in parent_ids): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") 
- return 0 - if any(group_id != -1 for group_id in group_ids[valid_length:]): - raise ValueError( - f"row {row_index}: valid tokens must be contiguous before padding" - ) - if any(parent_id == -1 for parent_id in parent_ids[:valid_length]): - raise ValueError( - f"row {row_index}: valid tokens must have non-padding parent_ids" - ) - if any(parent_id != -1 for parent_id in parent_ids[valid_length:]): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") - return valid_length - - -def _parse_row_tensor( - *, - row_index: int, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - valid_length: int, - first_family_index: int, - min_completions_per_family: int, -) -> list[GdnPackedFamilySpec]: - valid_groups = group_ids[:valid_length] - valid_parents = parent_ids[:valid_length] - if valid_length > 1: - same_group = valid_groups[1:] == valid_groups[:-1] - parent_changed = same_group & (valid_parents[1:] != valid_parents[:-1]) - if bool(torch.any(parent_changed).item()): - position = int(torch.nonzero(parent_changed, as_tuple=False)[0].item()) + 1 - group_id = int(valid_groups[position].item()) - previous_parent = int(valid_parents[position - 1].item()) - current_parent = int(valid_parents[position].item()) - raise ValueError( - f"row {row_index}: group {group_id} changes parent from " - f"{previous_parent} to {current_parent}" - ) - boundaries = torch.nonzero(~same_group, as_tuple=False).flatten() + 1 - starts_tensor = torch.cat( - (valid_groups.new_zeros(1), boundaries.to(valid_groups.dtype)) - ) - ends_tensor = torch.cat( - ( - boundaries.to(valid_groups.dtype), - valid_groups.new_tensor([valid_length]), - ) - ) - else: - starts_tensor = valid_groups.new_zeros(1) - ends_tensor = valid_groups.new_tensor([valid_length]) - - starts = tuple(int(value) for value in starts_tensor.tolist()) - ends = tuple(int(value) for value in ends_tensor.tolist()) - segment_group_ids = tuple(int(valid_groups[start].item()) for start in starts) - segment_parent_ids = 
tuple(int(valid_parents[start].item()) for start in starts) - families: list[GdnPackedFamilySpec] = [] - seen_groups: set[int] = set() - segment_cursor = 0 - while segment_cursor < len(starts): - group_id = segment_group_ids[segment_cursor] - parent_id = segment_parent_ids[segment_cursor] - start = starts[segment_cursor] - end = ends[segment_cursor] - if group_id in seen_groups: - raise ValueError(f"row {row_index}: group_id {group_id} is non-contiguous") - if group_id != parent_id: - raise ValueError( - f"row {row_index}: completion group {group_id} appears before " - f"its prefix parent {parent_id}" - ) - seen_groups.add(group_id) - family_index = first_family_index + len(families) - prefix = _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=row_index, - family_index=family_index, - group_id=group_id, - parent_id=parent_id, - start=start, - end=end, - kind="prefix", - child_index=None, - ) - segment_cursor += 1 - completions: list[GdnSegmentSpec] = [] - while segment_cursor < len(starts): - child_group_id = segment_group_ids[segment_cursor] - child_parent_id = segment_parent_ids[segment_cursor] - child_start = starts[segment_cursor] - child_end = ends[segment_cursor] - if child_group_id == child_parent_id: - break - if child_parent_id != group_id: - raise ValueError( - f"row {row_index}: completion group {child_group_id} has " - f"parent {child_parent_id}, expected active prefix {group_id}" - ) - if child_group_id in seen_groups: - raise ValueError( - f"row {row_index}: group_id {child_group_id} is non-contiguous" - ) - seen_groups.add(child_group_id) - completions.append( - _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=row_index, - family_index=family_index, - group_id=child_group_id, - parent_id=child_parent_id, - start=child_start, - end=child_end, - kind="completion", - child_index=len(completions), - ) - ) - segment_cursor += 1 - if len(completions) < min_completions_per_family: - 
raise ValueError( - f"row {row_index}: prefix group {group_id} has {len(completions)} " - f"completion(s), expected at least {min_completions_per_family}" - ) - families.append( - _trusted_pydantic_construct( - GdnPackedFamilySpec, - _GDN_PACKED_FAMILY_SPEC_FIELDS, - row_index=row_index, - family_index=family_index, - prefix=prefix, - completions=tuple(completions), - ) - ) - return families - - -def _parse_row( - *, - row_index: int, - group_ids: list[int], - parent_ids: list[int], - valid_length: int, - first_family_index: int, - min_completions_per_family: int, -) -> list[GdnPackedFamilySpec]: - families: list[GdnPackedFamilySpec] = [] - seen_groups: set[int] = set() - cursor = 0 - while cursor < valid_length: - group_id, parent_id, start, end = _read_segment( - row_index, group_ids, parent_ids, valid_length, cursor - ) - if group_id in seen_groups: - raise ValueError(f"row {row_index}: group_id {group_id} is non-contiguous") - if group_id != parent_id: - raise ValueError( - f"row {row_index}: completion group {group_id} appears before " - f"its prefix parent {parent_id}" - ) - seen_groups.add(group_id) - family_index = first_family_index + len(families) - prefix = GdnSegmentSpec( - row_index=row_index, - family_index=family_index, - group_id=group_id, - parent_id=parent_id, - start=start, - end=end, - kind="prefix", - ) - cursor = end - completions: list[GdnSegmentSpec] = [] - while cursor < valid_length: - child_group_id, child_parent_id, child_start, child_end = _read_segment( - row_index, group_ids, parent_ids, valid_length, cursor - ) - if child_group_id == child_parent_id: - break - if child_parent_id != group_id: - raise ValueError( - f"row {row_index}: completion group {child_group_id} has " - f"parent {child_parent_id}, expected active prefix {group_id}" - ) - if child_group_id in seen_groups: - raise ValueError( - f"row {row_index}: group_id {child_group_id} is non-contiguous" - ) - seen_groups.add(child_group_id) - completions.append( - 
GdnSegmentSpec( - row_index=row_index, - family_index=family_index, - group_id=child_group_id, - parent_id=child_parent_id, - start=child_start, - end=child_end, - kind="completion", - child_index=len(completions), - ) - ) - cursor = child_end - if len(completions) < min_completions_per_family: - raise ValueError( - f"row {row_index}: prefix group {group_id} has {len(completions)} " - f"completion(s), expected at least {min_completions_per_family}" - ) - families.append( - GdnPackedFamilySpec( - row_index=row_index, - family_index=family_index, - prefix=prefix, - completions=tuple(completions), - ) - ) - return families - - -def _read_segment( - row_index: int, - group_ids: list[int], - parent_ids: list[int], - valid_length: int, - cursor: int, -) -> tuple[int, int, int, int]: - group_id = int(group_ids[cursor]) - parent_id = int(parent_ids[cursor]) - if group_id < 0 or parent_id < 0: - raise ValueError(f"row {row_index}: segment ids must be non-negative") - start = cursor - cursor += 1 - while cursor < valid_length and int(group_ids[cursor]) == group_id: - current_parent = int(parent_ids[cursor]) - if current_parent != parent_id: - raise ValueError( - f"row {row_index}: group {group_id} changes parent from " - f"{parent_id} to {current_parent}" - ) - cursor += 1 - return group_id, parent_id, start, cursor diff --git a/src/art/megatron/gdn/layout.py b/src/art/megatron/gdn/layout.py index c3469a451..bd2ece79e 100644 --- a/src/art/megatron/gdn/layout.py +++ b/src/art/megatron/gdn/layout.py @@ -28,12 +28,18 @@ class GdnCpPeerTransfer(BaseModel): source_rank: int = Field(ge=0) dest_rank: int = Field(ge=0) token_count: int = Field(ge=0) + source_positions_cpu: tuple[int, ...] | None = None + dest_positions_cpu: tuple[int, ...] 
| None = None source_positions_tensor: Tensor | None = None dest_positions_tensor: Tensor | None = None @model_validator(mode="after") def _same_lengths(self) -> "GdnCpPeerTransfer": lengths = {int(self.token_count)} + if self.source_positions_cpu is not None: + lengths.add(len(self.source_positions_cpu)) + if self.dest_positions_cpu is not None: + lengths.add(len(self.dest_positions_cpu)) if self.source_positions_tensor is not None: lengths.add(int(self.source_positions_tensor.numel())) if self.dest_positions_tensor is not None: @@ -238,9 +244,13 @@ def _make_peer_transfer( source_count=source_count, dest_count=dest_count, ): + source_cpu = None + dest_cpu = None source_tensor = None dest_tensor = None else: + source_cpu = _tensor_positions_tuple(source_positions) + dest_cpu = _tensor_positions_tuple(dest_positions) target = torch.device(device) if device is not None else torch.device("cpu") source_tensor = source_positions.to( device=target, dtype=torch.long @@ -250,11 +260,17 @@ def _make_peer_transfer( source_rank=source_rank, dest_rank=dest_rank, token_count=token_count, + source_positions_cpu=source_cpu, + dest_positions_cpu=dest_cpu, source_positions_tensor=source_tensor, dest_positions_tensor=dest_tensor, ) +def _tensor_positions_tuple(tensor: Tensor) -> tuple[int, ...]: + return tuple(int(value) for value in tensor.detach().cpu().tolist()) + + def _is_full_identity_transfer( *, source_rank: int, @@ -287,6 +303,8 @@ def _reverse_exchange_plan(plan: GdnCpExchangePlan) -> GdnCpExchangePlan: source_rank=transfer.dest_rank, dest_rank=transfer.source_rank, token_count=_transfer_token_count(transfer), + source_positions_cpu=transfer.dest_positions_cpu, + dest_positions_cpu=transfer.source_positions_cpu, source_positions_tensor=transfer.dest_positions_tensor, dest_positions_tensor=transfer.source_positions_tensor, ) @@ -494,6 +512,8 @@ def move_cp_exchange_plan_to_device( source_rank=transfer.source_rank, dest_rank=transfer.dest_rank, 
token_count=transfer.token_count, + source_positions_cpu=transfer.source_positions_cpu, + dest_positions_cpu=transfer.dest_positions_cpu, source_positions_tensor=_move_optional_index_tensor( transfer.source_positions_tensor, target ), @@ -750,10 +770,15 @@ def _is_implicit_full_identity_transfer( ) -def _transfer_positions_tuple(tensor: Tensor | None) -> tuple[int, ...]: +def _transfer_positions_tuple( + positions: tuple[int, ...] | None, + tensor: Tensor | None, +) -> tuple[int, ...]: + if positions is not None: + return positions if tensor is None: return () - return tuple(int(value) for value in tensor.detach().cpu().tolist()) + return _tensor_positions_tuple(tensor) def _transfer_index_tensor( @@ -1028,7 +1053,10 @@ def _transfer_dest_positions_for_duplicate_check( dest_count=_dest_count_for_rank(plan, transfer.dest_rank), ): return tuple(range(token_count)) - positions = _transfer_positions_tuple(transfer.dest_positions_tensor) + positions = _transfer_positions_tuple( + transfer.dest_positions_cpu, + transfer.dest_positions_tensor, + ) if len(positions) != token_count: raise ValueError("GDN CP transfer destination positions must match token_count") return positions diff --git a/src/art/megatron/gdn/operator.py b/src/art/megatron/gdn/operator.py index e8a122f5c..c7c3aed96 100644 --- a/src/art/megatron/gdn/operator.py +++ b/src/art/megatron/gdn/operator.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import MethodType -from typing import Any, Callable, Literal, NamedTuple, Sequence, cast +from typing import Any, Callable, Iterable, Literal, NamedTuple, Sequence, cast import torch from torch import Tensor @@ -12,9 +12,9 @@ from .fla_cp import chunk_gated_delta_rule_native_cp from .gdn_shared_prefix import ( GdnPackedExecutionSpec, - GdnParentStateTransferPlan, GdnRankExecutionPlan, GdnSegmentBucketPlan, + GdnStateExchangePlan, build_gdn_rank_execution_plan, parse_gdn_shared_prefix_segments, ) @@ -518,23 +518,10 @@ def 
_run_planned_prefixes_and_completions( hidden_states: Tensor, plan: GdnRankExecutionPlan, ) -> tuple[Tensor, Tensor | None]: - if _has_chunk_aligned_local_plan(plan): - return _run_chunk_aligned_prefixes_and_completions(gdn, hidden_states, plan) - raise ValueError( - "shared-prefix GDN requires a chunk-aligned execution plan; " - "prefix/completion bucket execution has been removed" - ) - - -def _has_chunk_aligned_local_plan(plan: GdnRankExecutionPlan) -> bool: - return bool( - plan.prefix_boundary_buckets - or plan.prefix_tail_buckets - or plan.completion_with_prefix_tail_buckets - ) + return _run_tree_prefixes(gdn, hidden_states, plan) -def _run_chunk_aligned_prefixes_and_completions( +def _run_tree_prefixes( gdn: Any, hidden_states: Tensor, plan: GdnRankExecutionPlan, @@ -542,104 +529,382 @@ def _run_chunk_aligned_prefixes_and_completions( qkv, gate, beta, recurrent_g = _project_gdn_inputs(gdn, hidden_states) gate = gate.clone() recurrent_output = torch.zeros_like(gate) - boundary_family_chunks: list[Tensor] = [] - boundary_conv_chunks: list[Tensor] = [] - boundary_rec_chunks: list[Tensor] = [] + recurrent_output, _cp_dependency = _run_tree_depth_buckets( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + plan, + state_reference=hidden_states, + ) + return _project_gdn_output(gdn, recurrent_output, gate, plan) + + +def _run_tree_depth_buckets( + gdn: Any, + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + recurrent_output: Tensor, + plan: GdnRankExecutionPlan, + *, + state_reference: Tensor, + group: Any | None = None, + cp_dependency: Tensor | None = None, +) -> tuple[Tensor, Tensor | None]: + state_cache = _TreeStateChunkCache( + device=state_reference.device, + ) - for bucket in plan.prefix_boundary_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket + for depth, buckets in enumerate(plan.tree_segment_buckets_by_depth): + if depth < len(plan.tree_state_exchanges_by_depth): + cp_dependency = 
state_cache.exchange_remote_parent_states( + gdn, + plan.tree_state_exchanges_by_depth[depth], + state_reference=state_reference, + rank=plan.cp_rank, + group=group, + cp_dependency=cp_dependency, + ) + if depth < len(plan.tree_chain_buckets_by_depth): + for bucket in plan.tree_chain_buckets_by_depth[depth]: + recurrent_output, cp_dependency = _run_tree_bucket( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + state_cache, + bucket, + state_reference=state_reference, + group=group, + cp_dependency=cp_dependency, + recurrent_cp=True, + scale_parent_state_gradient=1.0 / plan.cp_size, + ) + + for bucket in buckets: + recurrent_output, cp_dependency = _run_tree_bucket( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + state_cache, + bucket, + state_reference=state_reference, + cp_dependency=cp_dependency, + ) + + return recurrent_output, cp_dependency + + +def _run_tree_bucket( + gdn: Any, + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + recurrent_output: Tensor, + state_cache: "_TreeStateChunkCache", + bucket: GdnSegmentBucketPlan, + *, + state_reference: Tensor, + group: Any | None = None, + cp_dependency: Tensor | None = None, + recurrent_cp: bool = False, + scale_parent_state_gradient: float | None = None, +) -> tuple[Tensor, Tensor | None]: + parent_conv, parent_rec = state_cache.parent_states( + gdn, + bucket, + state_reference=state_reference, + ) + if _bucket_has_parent_state(bucket): + parent_conv, parent_rec = _couple_parent_states(parent_conv, parent_rec) + if scale_parent_state_gradient is not None: + parent_conv = _scale_state_gradient( + parent_conv, + scale_parent_state_gradient, + ) + parent_rec = _scale_state_gradient(parent_rec, scale_parent_state_gradient) + segment_qkv, segment_beta, segment_g = _gather_bucket_streams( + qkv, + beta, + recurrent_g, + bucket, + ) + if cp_dependency is not None: + segment_qkv = _add_autograd_dependency(segment_qkv, cp_dependency) + segment_beta = _add_autograd_dependency(segment_beta, 
cp_dependency) + segment_g = _add_autograd_dependency(segment_g, cp_dependency) + parent_conv = _add_autograd_dependency(parent_conv, cp_dependency) + parent_rec = _add_autograd_dependency(parent_rec, cp_dependency) + segment_out, segment_conv, segment_rec = run_gdn_bucket( + bucket, + (segment_qkv, segment_beta, segment_g), + (parent_conv, parent_rec), + gdn=gdn, + group=group, + recurrent_cp=recurrent_cp, + output_final_state=bucket.needs_final_state or recurrent_cp, + ) + if bucket.needs_final_state and (segment_conv is None or segment_rec is None): + raise RuntimeError("tree GDN execution must return final states") + if ( + bucket.needs_final_state + and segment_conv is not None + and segment_rec is not None + ): + cp_dependency = _make_autograd_dependency( + segment_out, segment_conv, segment_rec ) - zero_conv = _zero_conv_state( - gdn, hidden_states, batch_size=bucket.segment_count + else: + cp_dependency = _make_autograd_dependency(segment_out) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, + bucket, + segment_out, + ) + if bucket.needs_final_state: + state_cache.append( + bucket, + cast(Tensor, segment_conv), + cast(Tensor, segment_rec), + ) + return recurrent_output, cp_dependency + + +class _TreeStateChunkCache: + def __init__(self, *, device: torch.device) -> None: + self._device = device + self._conv_chunks: list[Tensor] = [] + self._rec_chunks: list[Tensor] = [] + self._source_by_family: dict[int, tuple[int, int]] = {} + + def append(self, bucket: GdnSegmentBucketPlan, conv: Tensor, rec: Tensor) -> None: + self.append_families(_bucket_family_indices_cpu(bucket), conv, rec) + + def append_families( + self, family_indices: Sequence[int], conv: Tensor, rec: Tensor + ) -> None: + if len(family_indices) == 0: + return + if int(conv.shape[0]) != len(family_indices): + raise ValueError( + "tree GDN state cache conv batch must match family count, got " + f"{tuple(conv.shape)} and {len(family_indices)} families" + ) + if 
int(rec.shape[0]) != len(family_indices): + raise ValueError( + "tree GDN state cache recurrent batch must match family count, got " + f"{tuple(rec.shape)} and {len(family_indices)} families" + ) + chunk_index = len(self._conv_chunks) + self._conv_chunks.append(conv) + self._rec_chunks.append(rec) + for source_row, family_index in enumerate(family_indices): + self._source_by_family[int(family_index)] = (chunk_index, source_row) + + def exchange_remote_parent_states( + self, + gdn: Any, + exchange: GdnStateExchangePlan | None, + *, + state_reference: Tensor, + rank: int, + group: Any | None, + cp_dependency: Tensor | None, + ) -> Tensor | None: + if exchange is None: + return cp_dependency + from .layout import exchange_rank_tensor_all_to_all + + source_conv, source_rec = self.states_for_families( + gdn, + exchange.source_family_indices, + state_reference=state_reference, + ) + if cp_dependency is not None: + source_conv = _add_autograd_dependency(source_conv, cp_dependency) + source_rec = _add_autograd_dependency(source_rec, cp_dependency) + remote_conv = exchange_rank_tensor_all_to_all( + source_conv, + exchange.exchange, + rank=rank, + group=group, + backward_plan=exchange.reverse_exchange, ) - zero_rec = _zero_recurrent_state( - gdn, hidden_states, batch_size=bucket.segment_count + remote_rec = exchange_rank_tensor_all_to_all( + source_rec, + exchange.exchange, + rank=rank, + group=group, + backward_plan=exchange.reverse_exchange, ) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("prefix boundary GDN execution must return final states") - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - boundary_family_chunks.append(bucket.family_indices) - boundary_conv_chunks.append(prefix_conv) - boundary_rec_chunks.append(prefix_rec) - - 
boundary_conv_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_conv_chunks, - zero_state=_zero_conv_state(gdn, hidden_states, batch_size=plan.family_count), - ) - boundary_rec_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_rec_chunks, - zero_state=_zero_recurrent_state( - gdn, hidden_states, batch_size=plan.family_count - ), - ) - - tail_family_chunks: list[Tensor] = [] - tail_conv_chunks: list[Tensor] = [] - tail_rec_chunks: list[Tensor] = [] - for bucket in plan.prefix_tail_buckets: - tail_qkv, tail_beta, tail_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - tail_conv = boundary_conv_table.index_select(0, bucket.family_indices) - tail_rec = boundary_rec_table.index_select(0, bucket.family_indices) - tail_out, tail_conv, tail_rec = run_gdn_bucket( - bucket, - (tail_qkv, tail_beta, tail_g), - (tail_conv, tail_rec), - gdn=gdn, - output_final_state=True, + self.append_families(exchange.dest_family_indices, remote_conv, remote_rec) + dependency = _make_zero_autograd_dependency( + source_conv, source_rec, remote_conv, remote_rec ) - if tail_conv is None or tail_rec is None: - raise RuntimeError("prefix tail GDN execution must return final states") - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, tail_out + return dependency if cp_dependency is None else dependency + cp_dependency + + def states_for_families( + self, + gdn: Any, + family_indices: Sequence[int], + *, + state_reference: Tensor, + ) -> tuple[Tensor, Tensor]: + if len(family_indices) == 0: + conv = _zero_conv_state(gdn, state_reference, batch_size=0) + rec = _zero_recurrent_state(gdn, state_reference, batch_size=0) + return conv.requires_grad_(True), rec.requires_grad_(True) + return self._mixed_parent_states( + gdn, + tuple(int(index) for index in family_indices), + state_reference=state_reference, + 
batch_size=len(family_indices), + roots_allowed=False, ) - tail_family_chunks.append(bucket.family_indices) - tail_conv_chunks.append(tail_conv) - tail_rec_chunks.append(tail_rec) - prefix_conv_table = _replace_indexed_family_states( - boundary_conv_table, - family_chunks=tail_family_chunks, - state_chunks=tail_conv_chunks, - ) - prefix_rec_table = _replace_indexed_family_states( - boundary_rec_table, - family_chunks=tail_family_chunks, - state_chunks=tail_rec_chunks, - ) + def parent_states( + self, + gdn: Any, + bucket: GdnSegmentBucketPlan, + *, + state_reference: Tensor, + ) -> tuple[Tensor, Tensor]: + parent_indices = bucket.parent_indices + if parent_indices is None: + raise RuntimeError("tree GDN bucket is missing parent indices") + parent_indices_cpu = _bucket_parent_indices_cpu(bucket) + batch_size = bucket.segment_count + if all(parent_index < 0 for parent_index in parent_indices_cpu): + return ( + _zero_conv_state(gdn, state_reference, batch_size=batch_size), + _zero_recurrent_state(gdn, state_reference, batch_size=batch_size), + ) - for bucket in plan.completion_with_prefix_tail_buckets: - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, + return self._mixed_parent_states( + gdn, + parent_indices_cpu, + state_reference=state_reference, + batch_size=batch_size, ) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out + + def _mixed_parent_states( + self, + gdn: Any, + parent_indices_cpu: tuple[int, ...], + *, + state_reference: Tensor, + batch_size: int, + roots_allowed: bool = True, + ) -> tuple[Tensor, Tensor]: + 
sources_by_chunk: dict[int, list[tuple[int, int]]] = {} + missing_parents: list[int] = [] + for dest_row, parent_index in enumerate(parent_indices_cpu): + if parent_index < 0: + if roots_allowed: + continue + missing_parents.append(parent_index) + continue + source = self._source_by_family.get(parent_index) + if source is None: + missing_parents.append(parent_index) + continue + chunk_index, source_row = source + sources_by_chunk.setdefault(chunk_index, []).append((dest_row, source_row)) + if missing_parents: + raise RuntimeError( + "tree GDN append-only execution is missing parent state for " + f"families {tuple(missing_parents)}" + ) + + single_source_chunk = next(iter(sources_by_chunk.values())) + if len(sources_by_chunk) == 1 and len(single_source_chunk) == batch_size: + chunk_index, pairs = next(iter(sources_by_chunk.items())) + return ( + _select_state_rows(self._conv_chunks[chunk_index], pairs), + _select_state_rows(self._rec_chunks[chunk_index], pairs), + ) + + conv = _zero_conv_state(gdn, state_reference, batch_size=batch_size) + rec = _zero_recurrent_state(gdn, state_reference, batch_size=batch_size) + for chunk_index, pairs in sources_by_chunk.items(): + dest_rows = _long_tensor( + (dest_row for dest_row, _ in pairs), + device=self._device, + ) + source_rows = _long_tensor( + (source_row for _, source_row in pairs), + device=self._device, + ) + conv = conv.index_copy( + 0, + dest_rows, + self._conv_chunks[chunk_index].index_select(0, source_rows), + ) + rec = rec.index_copy( + 0, + dest_rows, + self._rec_chunks[chunk_index].index_select(0, source_rows), + ) + return conv, rec + + +def _select_state_rows(chunk: Tensor, pairs: Sequence[tuple[int, int]]) -> Tensor: + source_rows = tuple(source_row for _, source_row in pairs) + if len(set(source_rows)) == 1: + return chunk.narrow(0, source_rows[0], 1).expand( + len(source_rows), + *tuple(chunk.shape[1:]), ) - return _project_gdn_output(gdn, recurrent_output, gate, plan) + first_row = source_rows[0] + if 
source_rows == tuple(range(first_row, first_row + len(source_rows))): + return chunk.narrow(0, first_row, len(source_rows)) + return chunk.index_select( + 0, + _long_tensor(source_rows, device=chunk.device), + ) + + +def _bucket_family_indices_cpu(bucket: GdnSegmentBucketPlan) -> tuple[int, ...]: + family_indices = bucket.family_indices_cpu + if family_indices is None: + family_indices = bucket.family_indices.detach().cpu() + return tuple(int(index) for index in family_indices.tolist()) + + +def _bucket_parent_indices_cpu(bucket: GdnSegmentBucketPlan) -> tuple[int, ...]: + parent_indices = bucket.parent_indices + if parent_indices is None: + raise RuntimeError("tree GDN bucket is missing parent indices") + parent_indices_cpu = bucket.parent_indices_cpu + if parent_indices_cpu is None: + parent_indices_cpu = parent_indices.detach().cpu() + return tuple(int(index) for index in parent_indices_cpu.tolist()) + + +def _long_tensor(values: Iterable[int], *, device: torch.device) -> Tensor: + return torch.tensor(tuple(values), dtype=torch.long, device=device) + + +def _bucket_has_parent_state(bucket: GdnSegmentBucketPlan) -> bool: + parent_indices_cpu = bucket.parent_indices_cpu + if parent_indices_cpu is None: + parent_indices = bucket.parent_indices + if parent_indices is None: + raise RuntimeError("tree GDN bucket is missing parent indices") + parent_indices_cpu = parent_indices.detach().cpu() + return any(int(parent_index) >= 0 for parent_index in parent_indices_cpu.tolist()) + + +def _bucket_has_uniform_lengths(bucket: GdnSegmentBucketPlan) -> bool: + lengths_cpu = bucket.lengths_cpu + if lengths_cpu is None: + lengths_cpu = bucket.lengths.detach().cpu() + return all(int(length) == int(bucket.length) for length in lengths_cpu.tolist()) def _run_cp_planned_prefixes_and_completions( @@ -679,385 +944,21 @@ def _run_cp_planned_prefixes_and_completions( if empty_gdn_rank else _empty_autograd_dependency(qkv) ) - qkv_with_remote_tail = qkv - beta_with_remote_tail = beta - 
recurrent_g_with_remote_tail = recurrent_g - if plan.remote_prefix_tail_exchange is not None: - remote_qkv, remote_beta, remote_g = _exchange_remote_prefix_tail_streams( - qkv, - beta, - recurrent_g, - plan=plan, - group=group, - ) - qkv_with_remote_tail = torch.cat([qkv, remote_qkv.unsqueeze(0)], dim=1) - beta_with_remote_tail = torch.cat([beta, remote_beta.unsqueeze(0)], dim=1) - recurrent_g_with_remote_tail = torch.cat( - [recurrent_g, remote_g.unsqueeze(0)], dim=1 - ) - cp_dependency = cp_dependency + _make_zero_autograd_dependency( - remote_qkv, remote_beta, remote_g - ) + if not plan.tree_segment_buckets_by_depth: + raise ValueError("CP shared-prefix GDN requires a tree execution plan") gate = gate.clone() recurrent_output = torch.zeros_like(gate) - prefix_family_chunks: list[Tensor] = [] - prefix_conv_chunks: list[Tensor] = [] - prefix_rec_chunks: list[Tensor] = [] - - for bucket in plan.chain_prefix_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - zero_conv = _zero_conv_state(gdn, qkv, batch_size=bucket.segment_count) - zero_rec = _zero_recurrent_state(gdn, qkv, batch_size=bucket.segment_count) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - group=group, - recurrent_cp=True, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("CP prefix GDN execution must return final states") - prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) - prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) - prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) - cp_dependency = _make_autograd_dependency(prefix_out, prefix_conv, prefix_rec) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(prefix_conv) - 
prefix_rec_chunks.append(prefix_rec) - - boundary_family_chunks: list[Tensor] = [] - boundary_conv_chunks: list[Tensor] = [] - boundary_rec_chunks: list[Tensor] = [] - for bucket in plan.prefix_boundary_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - zero_conv = _zero_conv_state(gdn, qkv, batch_size=bucket.segment_count) - zero_rec = _zero_recurrent_state(gdn, qkv, batch_size=bucket.segment_count) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("local prefix GDN execution must return final states") - prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) - prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) - prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - boundary_family_chunks.append(bucket.family_indices) - boundary_conv_chunks.append(prefix_conv) - boundary_rec_chunks.append(prefix_rec) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(prefix_conv) - prefix_rec_chunks.append(prefix_rec) - - if ( - plan.prefix_tail_buckets - or plan.remote_prefix_tail_buckets - or plan.completion_with_prefix_tail_buckets - or plan.remote_completion_with_prefix_tail_buckets - or plan.remote_prefix_tail_state_transfers - ): - boundary_conv_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_conv_chunks, - zero_state=_zero_conv_state(gdn, qkv, batch_size=plan.family_count), - ) - boundary_rec_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_rec_chunks, - zero_state=_zero_recurrent_state(gdn, qkv, 
batch_size=plan.family_count), - ) - remote_boundary_conv_table = boundary_conv_table - remote_boundary_rec_table = boundary_rec_table - if plan.remote_prefix_tail_state_transfers: - ( - remote_boundary_conv_table, - remote_boundary_rec_table, - remote_boundary_dependency, - ) = _exchange_parent_state_rows( - boundary_conv_table, - boundary_rec_table, - transfers=plan.remote_prefix_tail_state_transfers, - group=group, - ) - cp_dependency = cp_dependency + remote_boundary_dependency - tail_family_chunks: list[Tensor] = [] - tail_conv_chunks: list[Tensor] = [] - tail_rec_chunks: list[Tensor] = [] - for bucket in plan.prefix_tail_buckets: - tail_qkv, tail_beta, tail_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - tail_conv = boundary_conv_table.index_select(0, bucket.family_indices) - tail_rec = boundary_rec_table.index_select(0, bucket.family_indices) - tail_out, tail_conv, tail_rec = run_gdn_bucket( - bucket, - (tail_qkv, tail_beta, tail_g), - (tail_conv, tail_rec), - gdn=gdn, - output_final_state=True, - ) - if tail_conv is None or tail_rec is None: - raise RuntimeError("local prefix tail GDN execution must return states") - tail_out = _add_autograd_dependency(tail_out, cp_dependency) - tail_conv = _add_autograd_dependency(tail_conv, cp_dependency) - tail_rec = _add_autograd_dependency(tail_rec, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, tail_out - ) - tail_family_chunks.append(bucket.family_indices) - tail_conv_chunks.append(tail_conv) - tail_rec_chunks.append(tail_rec) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(tail_conv) - prefix_rec_chunks.append(tail_rec) - for bucket in plan.remote_prefix_tail_buckets: - tail_qkv, tail_beta, tail_g = _gather_bucket_streams( - qkv_with_remote_tail, - beta_with_remote_tail, - recurrent_g_with_remote_tail, - bucket, - ) - tail_conv = remote_boundary_conv_table.index_select( - 0, bucket.family_indices - ) - tail_rec = 
remote_boundary_rec_table.index_select(0, bucket.family_indices) - tail_out, tail_conv, tail_rec = run_gdn_bucket( - bucket, - (tail_qkv, tail_beta, tail_g), - (tail_conv, tail_rec), - gdn=gdn, - output_final_state=True, - ) - if tail_conv is None or tail_rec is None: - raise RuntimeError( - "remote prefix tail GDN execution must return states" - ) - tail_out = _add_autograd_dependency(tail_out, cp_dependency) - tail_conv = _add_autograd_dependency(tail_conv, cp_dependency) - tail_rec = _add_autograd_dependency(tail_rec, cp_dependency) - tail_family_chunks.append(bucket.family_indices) - tail_conv_chunks.append(tail_conv) - tail_rec_chunks.append(tail_rec) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(tail_conv) - prefix_rec_chunks.append(tail_rec) - prefix_conv_table = _replace_indexed_family_states( - boundary_conv_table, - family_chunks=tail_family_chunks, - state_chunks=tail_conv_chunks, - ) - prefix_rec_table = _replace_indexed_family_states( - boundary_rec_table, - family_chunks=tail_family_chunks, - state_chunks=tail_rec_chunks, - ) - for bucket in plan.completion_with_prefix_tail_buckets: - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - for bucket in plan.remote_completion_with_prefix_tail_buckets: - completion_conv = prefix_conv_table.index_select(0, 
bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, - beta, - recurrent_g, - bucket, - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - - for bucket in plan.local_prefix_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - zero_conv = _zero_conv_state(gdn, qkv, batch_size=bucket.segment_count) - zero_rec = _zero_recurrent_state(gdn, qkv, batch_size=bucket.segment_count) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("local prefix GDN execution must return final states") - prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) - prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) - prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(prefix_conv) - prefix_rec_chunks.append(prefix_rec) - - if not prefix_conv_chunks and not plan.parent_state_exchange_family_indices: - projected, out_bias = _project_cp_gdn_output( - gdn, - recurrent_output, - gate, - plan, - group=group, - output_layout=output_layout, - ) - projected = _add_autograd_dependency(projected, cp_dependency) - return 
projected, out_bias - - prefix_conv_table = _materialize_ordered_family_state_table( - family_chunks=prefix_family_chunks, - state_chunks=prefix_conv_chunks, - zero_state=_zero_conv_state(gdn, qkv, batch_size=plan.family_count), - ) - prefix_rec_table = _materialize_ordered_family_state_table( - family_chunks=prefix_family_chunks, - state_chunks=prefix_rec_chunks, - zero_state=_zero_recurrent_state(gdn, qkv, batch_size=plan.family_count), - ) - parent_state_exchanged = False - if plan.chain_completion_buckets and plan.parent_state_exchange_family_indices: - if not plan.parent_state_transfers: - raise ValueError("CP parent-state exchange requires planned transfers") - prefix_conv_table, prefix_rec_table, exchange_dependency = ( - _exchange_parent_state_rows( - prefix_conv_table, - prefix_rec_table, - transfers=plan.parent_state_transfers, - group=group, - ) - ) - cp_dependency = cp_dependency + exchange_dependency - parent_state_exchanged = True - for bucket in plan.chain_completion_buckets: - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_conv = _scale_state_gradient(completion_conv, 1.0 / plan.cp_size) - completion_rec = _scale_state_gradient(completion_rec, 1.0 / plan.cp_size) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - group=group, - recurrent_cp=True, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - cp_dependency = _make_autograd_dependency(completion_out) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - - 
ready_completion_buckets = ( - plan.ready_local_completion_buckets - if plan.ready_local_completion_buckets or plan.remote_local_completion_buckets - else plan.local_completion_buckets + recurrent_output, cp_dependency = _run_tree_depth_buckets( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + plan, + state_reference=qkv, + group=group, + cp_dependency=cp_dependency, ) - for bucket in ready_completion_buckets: - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - - if plan.parent_state_exchange_family_indices and not parent_state_exchanged: - if not plan.parent_state_transfers: - raise ValueError("CP parent-state exchange requires planned transfers") - prefix_conv_table, prefix_rec_table, exchange_dependency = ( - _exchange_parent_state_rows( - prefix_conv_table, - prefix_rec_table, - transfers=plan.parent_state_transfers, - group=group, - ) - ) - cp_dependency = cp_dependency + exchange_dependency - - for bucket in plan.remote_local_completion_buckets: - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, 
completion_rec - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - projected, out_bias = _project_cp_gdn_output( gdn, recurrent_output, @@ -1065,8 +966,8 @@ def _run_cp_planned_prefixes_and_completions( plan, group=group, output_layout=output_layout, + dependency=cp_dependency, ) - projected = _add_autograd_dependency(projected, cp_dependency) return projected, out_bias @@ -1922,6 +1823,7 @@ def _project_cp_gdn_output( *, group: Any, output_layout: Literal["attention", "gdn"], + dependency: Tensor | None = None, ) -> tuple[Tensor, Tensor | None]: batch_size, seq_len, _, _ = recurrent_output.shape token_uids = ( @@ -1933,6 +1835,8 @@ def _project_cp_gdn_output( norm_out = _apply_gated_rms_norm(gdn, recurrent_output, gate) norm_out = norm_out.reshape(batch_size, seq_len, _local_value_dim(gdn)) norm_out = norm_out.transpose(0, 1).contiguous() + if dependency is not None: + norm_out = _add_autograd_dependency(norm_out, dependency) if token_uids is not None: token_uids = _replicated_layout_token_uids(plan, "gdn", hidden_states=norm_out) _attach_trace_token_uids(norm_out, token_uids) @@ -2271,6 +2175,36 @@ def _local_value_dim(gdn: Any) -> int: return _local_value_heads(gdn) * int(gdn.value_head_dim) +def _prepare_dense_recurrent_inputs( + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + *, + key_heads: int, + value_heads: int, + key_dim: int, + value_dim: int, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + key_channels = int(key_heads) * int(key_dim) + value_channels = int(value_heads) * int(value_dim) + query = qkv[..., :key_channels].reshape(*qkv.shape[:2], key_heads, key_dim) + key = qkv[..., key_channels : 2 * key_channels].reshape( + 
*qkv.shape[:2], + key_heads, + key_dim, + ) + value = qkv[..., 2 * key_channels : 2 * key_channels + value_channels].reshape( + *qkv.shape[:2], + value_heads, + value_dim, + ) + repeat = int(value_heads) // int(key_heads) + if repeat != 1: + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + return query, key, value, beta, recurrent_g + + def _scatter_bucket_recurrent_output( output: Tensor, bucket: GdnSegmentBucketPlan, bucket_output: Tensor ) -> Tensor: @@ -2289,269 +2223,6 @@ def _bucket_output_mask(bucket: GdnSegmentBucketPlan) -> Tensor: return bucket.real_mask if output_mask is None else output_mask -def _materialize_indexed_family_state_table( - *, - plan: GdnRankExecutionPlan, - family_chunks: list[Tensor], - state_chunks: list[Tensor], - zero_state: Tensor, -) -> Tensor: - table = zero_state.detach() - if not state_chunks: - return table.requires_grad_(True) - values = torch.cat(state_chunks, dim=0) - family_indices = torch.cat(family_chunks, dim=0) - return table.index_copy(0, family_indices, values) - - -def _materialize_ordered_family_state_table( - *, - family_chunks: list[Tensor], - state_chunks: list[Tensor], - zero_state: Tensor, -) -> Tensor: - if len(family_chunks) != len(state_chunks): - raise RuntimeError("family and state chunk counts must match") - table = zero_state.detach().requires_grad_(True) - for family_indices, states in zip(family_chunks, state_chunks, strict=True): - table = table.index_copy(0, family_indices, states) - return table - - -def _replace_indexed_family_states( - table: Tensor, - *, - family_chunks: list[Tensor], - state_chunks: list[Tensor], -) -> Tensor: - if not state_chunks: - return table - return table.index_copy( - 0, - torch.cat(family_chunks, dim=0), - torch.cat(state_chunks, dim=0), - ) - - -def _exchange_parent_state_rows( - conv_table: Tensor, - rec_table: Tensor, - *, - transfers: tuple[GdnParentStateTransferPlan, ...], - group: Any, -) -> tuple[Tensor, Tensor, 
Tensor]: - if not transfers: - return conv_table, rec_table, _empty_autograd_dependency(conv_table) - conv_table, rec_table = _ParentStateExchange.apply( - conv_table, rec_table, transfers, group - ) - return conv_table, rec_table, _make_autograd_dependency(conv_table, rec_table) - - -def _exchange_remote_prefix_tail_streams( - qkv: Tensor, - beta: Tensor, - recurrent_g: Tensor, - *, - plan: GdnRankExecutionPlan, - group: Any, -) -> tuple[Tensor, Tensor, Tensor]: - from .layout import exchange_rank_tensor_all_to_all - - if plan.remote_prefix_tail_exchange is None: - return ( - qkv.new_empty((0, int(qkv.shape[-1]))), - beta.new_empty((0, int(beta.shape[-1]))), - recurrent_g.new_empty((0, int(recurrent_g.shape[-1]))), - ) - if plan.remote_prefix_tail_backward_exchange is None: - raise ValueError("remote prefix-tail exchange requires a backward plan") - qkv_flat = qkv.reshape(-1, int(qkv.shape[-1])) - beta_flat = beta.reshape(-1, int(beta.shape[-1])) - g_flat = recurrent_g.reshape(-1, int(recurrent_g.shape[-1])) - kwargs = { - "plan": plan.remote_prefix_tail_exchange, - "rank": plan.cp_rank, - "group": group, - "backward_plan": plan.remote_prefix_tail_backward_exchange, - } - return ( - exchange_rank_tensor_all_to_all(qkv_flat, **kwargs), - exchange_rank_tensor_all_to_all(beta_flat, **kwargs), - exchange_rank_tensor_all_to_all(g_flat, **kwargs), - ) - - -class _ParentStateExchange(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - conv_table: Tensor, - rec_table: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - group: Any, - ) -> tuple[Tensor, Tensor]: - ctx.group = group - ctx.transfers = transfers - ctx.save_for_backward(conv_table, rec_table) - return ( - _exchange_parent_state_tensor_forward( - conv_table, - transfers, - group=group, - ), - _exchange_parent_state_tensor_forward( - rec_table, - transfers, - group=group, - ), - ) - - @staticmethod - def backward( - ctx: Any, *grad_outputs: Tensor | None - ) -> tuple[Tensor | None, 
Tensor | None, None, None]: - grad_conv, grad_rec = grad_outputs - conv_ref, rec_ref = ctx.saved_tensors - return ( - _exchange_parent_state_tensor_backward( - _zero_if_none(grad_conv, conv_ref), - ctx.transfers, - group=ctx.group, - ), - _exchange_parent_state_tensor_backward( - _zero_if_none(grad_rec, rec_ref), - ctx.transfers, - group=ctx.group, - ), - None, - None, - ) - - -def _exchange_parent_state_tensor_forward( - table: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - group: Any, -) -> Tensor: - rank = torch.distributed.get_rank(group) # ty: ignore[possibly-missing-attribute] - output = table.clone() - recvs = _exchange_parent_state_rows_all_to_all( - table, transfers, rank=rank, reverse=False, group=group - ) - for transfer, rows in recvs: - index = _parent_state_index_tensor(transfer, device=table.device) - output.index_copy_(0, index, rows) - return output - - -def _exchange_parent_state_tensor_backward( - grad_output: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - group: Any, -) -> Tensor: - rank = torch.distributed.get_rank(group) # ty: ignore[possibly-missing-attribute] - grad_input = grad_output.clone() - for transfer in transfers: - if transfer.dest_rank != rank: - continue - index = _parent_state_index_tensor(transfer, device=grad_output.device) - grad_input.index_fill_(0, index, 0) - recvs = _exchange_parent_state_rows_all_to_all( - grad_output, transfers, rank=rank, reverse=True, group=group - ) - for transfer, rows in recvs: - index = _parent_state_index_tensor(transfer, device=grad_output.device) - grad_input.index_add_(0, index, rows) - return grad_input - - -def _zero_if_none(grad: Tensor | None, reference: Tensor) -> Tensor: - if grad is None: - return reference.new_zeros(reference.shape) - return grad.contiguous() - - -def _exchange_parent_state_rows_all_to_all( - table: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - rank: int, - reverse: bool, - group: Any, -) -> 
list[tuple[GdnParentStateTransferPlan, Tensor]]: - world_size = torch.distributed.get_world_size(group) # ty: ignore[possibly-missing-attribute] - send_counts = [0 for _ in range(world_size)] - recv_counts = [0 for _ in range(world_size)] - send_pieces: list[Tensor] = [] - for peer_rank in range(world_size): - for transfer in transfers: - send_rank = transfer.dest_rank if reverse else transfer.source_rank - recv_rank = transfer.source_rank if reverse else transfer.dest_rank - if send_rank == recv_rank: - continue - row_count = len(transfer.family_indices) - if rank == send_rank and peer_rank == recv_rank: - index = _parent_state_index_tensor(transfer, device=table.device) - send_pieces.append(table.index_select(0, index).contiguous()) - send_counts[peer_rank] += row_count - if rank == recv_rank and peer_rank == send_rank: - recv_counts[peer_rank] += row_count - - trailing_shape = tuple(table.shape[1:]) - send_buffer = ( - torch.cat(send_pieces, dim=0) - if send_pieces - else table.new_empty((0, *trailing_shape)) - ) - recv_buffer = table.new_empty((sum(recv_counts), *trailing_shape)) - work = torch.distributed.all_to_all_single( # ty: ignore[possibly-missing-attribute] - recv_buffer, - send_buffer, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=group, - async_op=True, - ) - work.wait() - - recvs: list[tuple[GdnParentStateTransferPlan, Tensor]] = [] - offset = 0 - for peer_rank, count in enumerate(recv_counts): - peer_end = offset + count - for transfer in transfers: - send_rank = transfer.dest_rank if reverse else transfer.source_rank - recv_rank = transfer.source_rank if reverse else transfer.dest_rank - if send_rank == recv_rank: - continue - if rank != recv_rank or peer_rank != send_rank: - continue - rows = len(transfer.family_indices) - recvs.append((transfer, recv_buffer[offset : offset + rows])) - offset += rows - if offset != peer_end: - raise RuntimeError( - "parent-state exchange unpack mismatch: " - f"rank={rank} 
peer={peer_rank} consumed={offset} expected={peer_end}" - ) - return recvs - - -def _parent_state_index_tensor( - transfer: GdnParentStateTransferPlan, - *, - device: torch.device, -) -> Tensor: - if ( - transfer.family_indices_tensor is not None - and transfer.family_indices_tensor.device == device - ): - return transfer.family_indices_tensor - return torch.tensor(transfer.family_indices, device=device, dtype=torch.long) - - def run_gdn_bucket( bucket: GdnSegmentBucketPlan, projected_streams: tuple[Tensor, Tensor, Tensor], @@ -2597,14 +2268,17 @@ def run_gdn_bucket( conv_output_final_state = output_final_state chain_conv_final: Tensor | None = None + chain_gradient_dependency: Tensor | None = None if recurrent_cp: - conv_initial, chain_conv_final = _chain_conv_initial_and_final( - qkv, - bucket.cu_seqlens_cpu, - bucket.lengths_by_rank_cpu, - conv_initial, - group=group, - output_final_state=output_final_state, + conv_initial, chain_conv_final, chain_gradient_dependency = ( + _chain_conv_initial_and_final( + qkv, + bucket.cu_seqlens_cpu, + bucket.lengths_by_rank_cpu, + conv_initial, + group=group, + output_final_state=output_final_state, + ) ) conv_output_final_state = False @@ -2618,15 +2292,31 @@ def run_gdn_bucket( if recurrent_cp: conv_final = chain_conv_final - query, key, value, beta, recurrent_g = _prepare_packed_recurrent_inputs_fused( - qkv, - beta, - recurrent_g, - key_heads=_local_key_heads(gdn), - value_heads=_local_value_heads(gdn), - key_dim=int(gdn.key_head_dim), - value_dim=int(gdn.value_head_dim), - ) + dense_local_bucket = not recurrent_cp and _bucket_has_uniform_lengths(bucket) + if dense_local_bucket: + query, key, value, beta, recurrent_g = _prepare_dense_recurrent_inputs( + qkv.reshape(batch_size, int(bucket.length), int(qkv.shape[-1])), + beta.reshape(batch_size, int(bucket.length), int(beta.shape[-1])), + recurrent_g.reshape( + batch_size, + int(bucket.length), + int(recurrent_g.shape[-1]), + ), + key_heads=_local_key_heads(gdn), + 
value_heads=_local_value_heads(gdn), + key_dim=int(gdn.key_head_dim), + value_dim=int(gdn.value_head_dim), + ) + else: + query, key, value, beta, recurrent_g = _prepare_packed_recurrent_inputs_fused( + qkv, + beta, + recurrent_g, + key_heads=_local_key_heads(gdn), + value_heads=_local_value_heads(gdn), + key_dim=int(gdn.key_head_dim), + value_dim=int(gdn.value_head_dim), + ) if gdn.use_qk_l2norm: query = _l2norm(query.contiguous()) key = _l2norm(key.contiguous()) @@ -2657,8 +2347,27 @@ def run_gdn_bucket( initial_state=recurrent_initial, output_final_state=output_final_state, use_qk_l2norm_in_kernel=False, - cu_seqlens=bucket.cu_seqlens, - ) + cu_seqlens=None if dense_local_bucket else bucket.cu_seqlens, + ) + if dense_local_bucket: + recurrent_out = recurrent_out.reshape( + 1, + token_count, + int(recurrent_out.shape[-2]), + int(recurrent_out.shape[-1]), + ) + if chain_gradient_dependency is not None: + recurrent_out = _add_autograd_dependency( + recurrent_out, + chain_gradient_dependency, + ) + if conv_final is not None: + conv_final = _add_autograd_dependency(conv_final, chain_gradient_dependency) + if recurrent_final is not None: + recurrent_final = _add_autograd_dependency( + recurrent_final, + chain_gradient_dependency, + ) return recurrent_out, conv_final, recurrent_final @@ -2670,15 +2379,22 @@ def _chain_conv_initial_and_final( *, group: Any, output_final_state: bool, -) -> tuple[Tensor, Tensor | None]: +) -> tuple[Tensor, Tensor | None, Tensor]: if group is None: raise ValueError("CP chain conv state requires a process group") if not dist.is_available() or not dist.is_initialized(): # ty: ignore[possibly-missing-attribute] raise RuntimeError("torch.distributed must be initialized for CP chain conv") - parent_initial = _AllReduceGradient.apply(parent_initial, group) + parent_initial, gradient_dependency = _AllReduceGradient.apply( + parent_initial, + group, + ) tail_width = int(parent_initial.shape[-1]) if tail_width <= 0: - return parent_initial, 
parent_initial if output_final_state else None + return ( + parent_initial, + parent_initial if output_final_state else None, + gradient_dependency, + ) if lengths_by_rank_cpu is None: raise ValueError("CP chain conv requires static all-rank bucket lengths") if cu_seqlens_cpu.device.type != "cpu" or lengths_by_rank_cpu.device.type != "cpu": @@ -2705,7 +2421,7 @@ def _chain_conv_initial_and_final( if output_final_state else None ) - return conv_initial, conv_final + return conv_initial, conv_final, gradient_dependency def _local_packed_conv_tail( @@ -2782,14 +2498,20 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> tuple[Tensor, None]: class _AllReduceGradient(torch.autograd.Function): @staticmethod - def forward(ctx: Any, tensor: Tensor, group: Any) -> Tensor: + def forward(ctx: Any, tensor: Tensor, group: Any) -> tuple[Tensor, Tensor]: ctx.group = group - return tensor + ctx.save_for_backward(tensor) + return tensor, tensor.new_zeros(()) @staticmethod - def backward(ctx: Any, *grad_outputs: Tensor) -> tuple[Tensor, None]: - (grad_output,) = grad_outputs - grad_input = grad_output.contiguous() + def backward(ctx: Any, *grad_outputs: Tensor | None) -> tuple[Tensor, None]: + grad_output, _grad_dependency = grad_outputs + (reference,) = ctx.saved_tensors + grad_input = ( + reference.new_zeros(reference.shape) + if grad_output is None + else grad_output.contiguous() + ) dist.all_reduce( # ty: ignore[possibly-missing-attribute] grad_input, op=dist.ReduceOp.SUM, # ty: ignore[possibly-missing-attribute] diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 4cea46b2a..27fb2b30d 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -1,9 +1,13 @@ -from collections.abc import Sequence +from collections.abc import Iterator, Sequence +from contextlib import contextmanager +import contextvars +from dataclasses import dataclass +import functools import json import math import os import re -from typing import Any, Literal, NamedTuple, cast 
+from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.core import parallel_state as ps @@ -42,6 +46,8 @@ ShardDomain = Literal["tp", "expert_tp"] GradSyncDomain = Literal["tp_default", "expert_tp"] GradSyncOp = Literal["none", "sum", "avg"] +LoraSlotKind = Literal["checkpoint", "lora"] +_F = TypeVar("_F", bound=Callable[..., Any]) TP_DEFAULT_GRAD_SYNC_DOMAIN: GradSyncDomain = "tp_default" EXPERT_TP_GRAD_SYNC_DOMAIN: GradSyncDomain = "expert_tp" @@ -50,6 +56,158 @@ GRAD_SYNC_OP_AVG: GradSyncOp = "avg" +@dataclass(frozen=True) +class LoRASlotRef: + kind: LoraSlotKind + name: str | None + + +@dataclass(frozen=True) +class _LoRASlotContext: + ref: LoRASlotRef + + +_CURRENT_LORA_SLOT: contextvars.ContextVar[_LoRASlotContext | None] = ( + contextvars.ContextVar("art_megatron_current_lora_slot", default=None) +) + + +def set_lora_slot_context( + ref: LoRASlotRef | None, +) -> contextvars.Token[_LoRASlotContext | None]: + """Select a dynamic LoRA slot for the current execution context. + + ``None`` preserves the legacy single-adapter path. ``LoRASlotRef(..., None)`` + explicitly selects the base model and makes every LoRA site an identity. 
+ """ + + return _CURRENT_LORA_SLOT.set(None if ref is None else _LoRASlotContext(ref)) + + +def reset_lora_slot_context( + token: contextvars.Token[_LoRASlotContext | None], +) -> None: + _CURRENT_LORA_SLOT.reset(token) + + +@contextmanager +def use_lora_slot(ref: LoRASlotRef | None) -> Iterator[None]: + token = set_lora_slot_context(ref) + try: + yield + finally: + reset_lora_slot_context(token) + + +def _with_captured_lora_slot(function: _F) -> _F: + context = _CURRENT_LORA_SLOT.get() + + @functools.wraps(function) + def wrapped(*args: Any, **kwargs: Any) -> Any: + token = _CURRENT_LORA_SLOT.set(context) + try: + return function(*args, **kwargs) + finally: + _CURRENT_LORA_SLOT.reset(token) + + return cast(_F, wrapped) + + +def _patch_function_once(module: Any, name: str, wrapper: Callable[[_F], _F]) -> None: + original = getattr(module, name, None) + if original is None or getattr(original, "_art_lora_slot_context_patch", False): + return + patched = wrapper(original) + setattr(patched, "_art_lora_slot_context_patch", True) + setattr(module, name, patched) + + +def install_lora_checkpoint_context_hooks() -> None: + """Preserve the selected dynamic LoRA slot across activation recompute.""" + + def wrap_torch_checkpoint(original: _F) -> _F: + @functools.wraps(original) + def checkpoint(function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + return original(_with_captured_lora_slot(function), *args, **kwargs) + + return cast(_F, checkpoint) + + def wrap_megatron_checkpoint(original: _F) -> _F: + @functools.wraps(original) + def checkpoint( + function: Callable[..., Any], + distribute_saved_activations: bool, + *args: Any, + ) -> Any: + return original( + _with_captured_lora_slot(function), + distribute_saved_activations, + *args, + ) + + return cast(_F, checkpoint) + + def wrap_checkpoint_without_output(original: _F) -> _F: + @functools.wraps(original) + def checkpoint(self: Any, function: Callable[..., Any], *args: Any) -> Any: + return original(self, 
_with_captured_lora_slot(function), *args) + + return cast(_F, checkpoint) + + def wrap_te_checkpoint(original: _F) -> _F: + @functools.wraps(original) + def checkpoint( + forward_func: Callable[..., Any], + *args: Any, + **kwargs: Any, + ) -> Any: + return original(_with_captured_lora_slot(forward_func), *args, **kwargs) + + return cast(_F, checkpoint) + + try: + import torch.utils.checkpoint as torch_checkpoint + + _patch_function_once(torch_checkpoint, "checkpoint", wrap_torch_checkpoint) + except Exception: + pass + + try: + import megatron.core.tensor_parallel as tensor_parallel + import megatron.core.tensor_parallel.random as megatron_random + + _patch_function_once(tensor_parallel, "checkpoint", wrap_megatron_checkpoint) + _patch_function_once(megatron_random, "checkpoint", wrap_megatron_checkpoint) + checkpoint_without_output = getattr( + megatron_random, "CheckpointWithoutOutput", None + ) + if checkpoint_without_output is not None: + _patch_function_once( + checkpoint_without_output, + "checkpoint", + wrap_checkpoint_without_output, + ) + except Exception: + pass + + try: + import megatron.core.transformer.transformer_block as transformer_block + + _patch_function_once(transformer_block, "te_checkpoint", wrap_te_checkpoint) + except Exception: + pass + + try: + import transformer_engine.pytorch.distributed as te_distributed + + _patch_function_once(te_distributed, "checkpoint", wrap_te_checkpoint) + except Exception: + pass + + +install_lora_checkpoint_context_hooks() + + class LoRAParallelSpec(BaseModel): # This spec only describes TP / expert-TP behavior. # DP/CP vs expert-DP behavior is selected separately via `allreduce`. 
@@ -307,6 +465,59 @@ def _exported_shard_dim(param: torch.nn.Parameter) -> int: return 1 - axis +def _copy_lora_param_metadata( + source: torch.nn.Parameter, + target: torch.nn.Parameter, +) -> None: + for name in ( + "lora_shard_domain", + "lora_tp_sharded", + "lora_tp_replicated", + "lora_tp_shard_dim", + "grad_sync_domain", + "grad_sync_op", + "allreduce", + "average_gradients_across_tp_domain", + "tensor_model_parallel", + "partition_dim", + "partition_stride", + "lora_tp_shard_strategy", + "lora_tp_component_sizes", + ): + if hasattr(source, name): + setattr(target, name, getattr(source, name)) + setattr(target, "_art_dynamic_lora_slot", True) + + +class LoRASlot(torch.nn.Module): + def __init__( + self, + *, + ref: LoRASlotRef, + a_t: torch.Tensor, + b_t: torch.Tensor, + alpha: float, + a_template: torch.nn.Parameter, + b_template: torch.nn.Parameter, + requires_grad: bool, + ) -> None: + super().__init__() + self.ref = ref + self.alpha = float(alpha) + self.A_T = torch.nn.Parameter(a_t.detach().clone(), requires_grad=requires_grad) + self.B_T = torch.nn.Parameter(b_t.detach().clone(), requires_grad=requires_grad) + _copy_lora_param_metadata(a_template, self.A_T) + _copy_lora_param_metadata(b_template, self.B_T) + + @property + def rank(self) -> int: + return int(self.A_T.shape[-1]) + + @property + def scale(self) -> float: + return self.alpha / self.rank + + class LoRA(torch.nn.Module): def __init__( self, @@ -327,7 +538,12 @@ def __init__( "adapter_model_prefix must contain the '{expert}' format placeholder if num_local_experts > 1" ) self.adapter_model_prefix = adapter_model_prefix + self.alpha = float(alpha) + self.in_features = int(in_features) + self.out_features = int(out_features) self.scale = alpha / rank + self._slot_modules = torch.nn.ModuleDict() + self._slot_keys: dict[LoRASlotRef, str] = {} self.A_T = torch.nn.Parameter( torch.zeros( num_local_experts, in_features, rank, dtype=dtype, device=device @@ -395,6 +611,86 @@ def 
_expected_weight_keys(self, suffix: str) -> list[str]: ] return [f"{self.adapter_model_prefix}.{suffix}.weight"] + def has_lora_slot(self, ref: LoRASlotRef) -> bool: + return ref in self._slot_keys + + def load_lora_slot( + self, + ref: LoRASlotRef, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float = LORA_ALPHA, + requires_grad: bool, + ) -> bool: + if ref.name is None: + raise ValueError("base-model slot refs do not own LoRA tensors") + keys = { + suffix: self._expected_weight_keys(suffix) + for suffix in ("lora_A", "lora_B") + } + present = { + suffix: [key in adapter_model for key in suffix_keys] + for suffix, suffix_keys in keys.items() + } + if not any(any(values) for values in present.values()): + return False + missing_keys = [ + key + for suffix, suffix_keys in keys.items() + for key, is_present in zip(suffix_keys, present[suffix], strict=True) + if not is_present + ] + if missing_keys: + raise KeyError( + f"Incomplete LoRA slot {ref.kind}:{ref.name} for " + f"{self.adapter_model_prefix}: {sorted(missing_keys)}" + ) + a_t = self._localized_weight( + self._adapter_weight(adapter_model, suffix="lora_A"), + into=self.A_T, + ) + b_t = self._localized_weight( + self._adapter_weight(adapter_model, suffix="lora_B"), + into=self.B_T, + ) + slot_key = self._slot_keys.get(ref) + if slot_key is None: + slot_key = f"slot_{len(self._slot_keys)}" + self._slot_keys[ref] = slot_key + elif self._has_live_slot_grads(ref): + raise RuntimeError( + f"Cannot overwrite live LoRA slot {ref.kind}:{ref.name} for " + f"{self.adapter_model_prefix}; clear grads/backward graph first." 
+ ) + self._slot_modules[slot_key] = LoRASlot( + ref=ref, + a_t=a_t, + b_t=b_t, + alpha=alpha, + a_template=self.A_T, + b_template=self.B_T, + requires_grad=requires_grad, + ) + return True + + def lora_slot_params(self, ref: LoRASlotRef) -> list[torch.nn.Parameter]: + slot = self._slot(ref) + if slot is None: + return [] + return [slot.A_T, slot.B_T] + + def _slot(self, ref: LoRASlotRef) -> LoRASlot | None: + key = self._slot_keys.get(ref) + if key is None: + return None + return cast(LoRASlot, self._slot_modules[key]) + + def _has_live_slot_grads(self, ref: LoRASlotRef) -> bool: + slot = self._slot(ref) + return slot is not None and any( + param.grad is not None for param in (slot.A_T, slot.B_T) + ) + def load_lora(self, adapter_model: dict[str, torch.Tensor]) -> None: missing_keys = [ key @@ -417,6 +713,17 @@ def load_lora(self, adapter_model: dict[str, torch.Tensor]) -> None: into=self.B_T, ) + def _adapter_weight( + self, + adapter_model: dict[str, torch.Tensor], + *, + suffix: str, + ) -> torch.Tensor: + keys = self._expected_weight_keys(suffix) + if self.num_local_experts > 1: + return torch.stack([adapter_model[key].T for key in keys]) + return adapter_model[keys[0]].T + def load_weights( self, adapter_model: dict[str, torch.Tensor], @@ -424,14 +731,12 @@ def load_weights( suffix: str, into: torch.nn.Parameter, ) -> None: - keys = self._expected_weight_keys(suffix) - if self.num_local_experts > 1: - weight = torch.stack([adapter_model[key].T for key in keys]) - else: - weight = adapter_model[keys[0]].T + weight = self._adapter_weight(adapter_model, suffix=suffix) self.load_weight(weight, into=into) - def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: + def _localized_weight( + self, weight: torch.Tensor, *, into: torch.nn.Parameter + ) -> torch.Tensor: domain = into.lora_shard_domain # ty: ignore[unresolved-attribute] if into.lora_tp_sharded: # ty: ignore[unresolved-attribute] axis = into.lora_tp_shard_dim # ty: 
ignore[unresolved-attribute] @@ -470,11 +775,10 @@ def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None raise ValueError( f"{self.adapter_model_prefix}: unsupported shard strategy={strategy}" ) - elif tuple(weight.shape) != tuple(into.shape): - raise ValueError( - f"{self.adapter_model_prefix}: unsharded load shape mismatch, got {tuple(weight.shape)} " - f"expected {tuple(into.shape)}" - ) + return weight.contiguous() + + def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: + weight = self._localized_weight(weight, into=into) if tuple(weight.shape) != tuple(into.shape): raise ValueError( f"{self.adapter_model_prefix}: sharded load shape mismatch, got {tuple(weight.shape)} " @@ -575,9 +879,29 @@ def sharded_lora_grad_dict(self) -> dict[str, torch.Tensor]: grads[key] = local_grad.T return grads + def active_lora_tensors( + self, + ) -> tuple[torch.Tensor, torch.Tensor, float] | None: + context = _CURRENT_LORA_SLOT.get() + if context is None: + return self.A_T, self.B_T, self.scale + if context.ref.name is None: + return None + slot = self._slot(context.ref) + if slot is None: + return None + return slot.A_T, slot.B_T, slot.scale + + def _zero_output(self, x: torch.Tensor) -> torch.Tensor: + return x.new_zeros((*x.shape[:-1], self.out_features)) + def forward( self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor | None = None ) -> torch.Tensor: + active = self.active_lora_tensors() + if active is None: + return self._zero_output(x) + a_t, b_t, scale = active if tokens_per_expert is not None: assert self.num_local_experts > 1, ( "tokens_per_expert is only supported if num_local_experts > 1" @@ -586,12 +910,12 @@ def forward( if isinstance(bsz, list): bsz = torch.tensor(bsz, dtype=torch.int64, device="cpu") if x.shape[0] == 0: - return x.new_zeros((x.shape[0], self.B_T.shape[-1])) - return quack_grouped_lora(x, self.A_T, self.B_T, bsz, scale=self.scale) - out = (x @ self.A_T) @ self.B_T - if 
def load_lora_slot_into_model(
    model: Sequence[torch.nn.Module],
    ref: LoRASlotRef,
    adapter_model: dict[str, torch.Tensor],
    *,
    alpha: float = LORA_ALPHA,
    requires_grad: bool,
) -> int:
    """Load ``adapter_model`` into slot ``ref`` on every LoRA site of ``model``.

    Walks every submodule of every chunk, asks each ``LoRA`` module to load
    the slot, and counts the modules whose ``load_lora_slot`` returned a
    truthy value.

    Returns:
        Number of LoRA modules that reported loading the slot.

    Raises:
        RuntimeError: if nothing was loaded although ``ref.name`` is set.
    """
    site_count = 0
    for chunk in model:
        for candidate in chunk.modules():
            if not isinstance(candidate, LoRA):
                continue
            accepted = candidate.load_lora_slot(
                ref,
                adapter_model,
                alpha=alpha,
                requires_grad=requires_grad,
            )
            if accepted:
                site_count += 1
    if site_count == 0 and ref.name is not None:
        raise RuntimeError(f"LoRA slot {ref.kind}:{ref.name} loaded no adapter sites")
    return site_count
module.lora_slot_params(ref): + param_id = id(param) + if param_id in seen: + continue + seen.add(param_id) + yield param diff --git a/src/art/megatron/model_support/spec.py b/src/art/megatron/model_support/spec.py index 15c6f8d96..92c1368a2 100644 --- a/src/art/megatron/model_support/spec.py +++ b/src/art/megatron/model_support/spec.py @@ -75,6 +75,7 @@ class ModelSupportSpec(BaseModel): class ModelSupportHandler(Protocol): key: str is_moe: bool + build_gdn_execution_spec: bool native_vllm_lora_status: NativeVllmLoraStatus def identity_lora_model_config(self, base_config: Any) -> Any: ... diff --git a/src/art/megatron/setup.sh b/src/art/megatron/setup.sh index 6d3a5548c..3e5a1cb51 100755 --- a/src/art/megatron/setup.sh +++ b/src/art/megatron/setup.sh @@ -36,3 +36,7 @@ if [ -x "${HOME}/.local/bin/uv" ]; then uv_bin="${HOME}/.local/bin/uv" fi "${uv_bin}" sync --extra backend --extra megatron --frozen --active + +if [ "${INSTALL_VLLM_RUNTIME:-true}" = "true" ]; then + "${uv_bin}" sync --project vllm_runtime --frozen --no-dev +fi diff --git a/src/art/megatron/shared_prefix_packing.py b/src/art/megatron/shared_prefix_packing.py new file mode 100644 index 000000000..cbcaf6092 --- /dev/null +++ b/src/art/megatron/shared_prefix_packing.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class SharedPrefixPack: + tokens: torch.Tensor + group_ids: torch.Tensor + parent_ids: torch.Tensor + position_ids: torch.Tensor + positions_by_sequence: tuple[torch.Tensor, ...] + + +def pack_shared_prefixes( + sequences: Iterable[torch.Tensor], + *, + max_depth: int, +) -> SharedPrefixPack: + """Pack token sequences by storing shared prefixes once. + + This is the small packing step that lets `TrainerRank.forward()` run one + model pass over a compact prefix tree instead of replaying the same prompt + tokens for every request. 
Think of each input sequence as a path through a + tree: when several paths start with the same tokens, this function writes + that shared segment once, then writes each branch after it. + + Args: + sequences: 1-D token tensors to pack. + max_depth: How many nested shared-prefix levels to emit. `0` disables + prefix sharing and writes each sequence as its own root segment. `1` + shares the first common segment in each branch; larger values allow + branches to contain shared sub-branches. + + Returns: + `tokens` is the compact model input, shaped `[1, packed_length]`. + `group_ids` and `parent_ids` describe the prefix tree to shared-prefix + attention. Positions in the same emitted segment share a group, and each + group points at the parent segment it continues from. Root groups point + to themselves. + `position_ids` keeps each token's original sequence position for + positional embeddings/rotary attention. + `positions_by_sequence` is the reverse index used after the model call + to unpack logits, logprobs, or hidden states back into one tensor per + original request. + + The implementation is a tiny radix-tree walk. It finds the longest prefix + shared by the active sequences, emits that segment once, then partitions the + remaining sequences by their next token while preserving first-seen order. + Single sequences, empty branches, and branches past `max_depth` are emitted + as ordinary unshared tails. 
+ """ + if max_depth < 0: + raise ValueError("max_depth must be >= 0") + + tensors = tuple(_sequence_tensor(sequence) for sequence in sequences) + if not tensors: + return _empty_pack() + + device = tensors[0].device + lengths = torch.tensor([len(tensor) for tensor in tensors], device=device) + if int(lengths.max().item()) == 0: + return _empty_pack(len(tensors), device=device) + + padded = torch.nn.utils.rnn.pad_sequence(list(tensors), batch_first=True) + token_chunks: list[torch.Tensor] = [] + group_chunks: list[torch.Tensor] = [] + parent_chunks: list[torch.Tensor] = [] + position_chunks: list[torch.Tensor] = [] + positions_by_sequence: list[list[torch.Tensor]] = [[] for _ in tensors] + cursor = 0 + next_group_id = 1 + + def emit( + indices: torch.Tensor, + start: int, + end: int, + parent_group_id: int | None, + ) -> int: + nonlocal cursor, next_group_id + segment = tensors[int(indices[0].item())][start:end] + group_id = next_group_id + next_group_id += 1 + parent_id = group_id if parent_group_id is None else parent_group_id + packed_positions = torch.arange(cursor, cursor + len(segment), device=device) + + token_chunks.append(segment) + group_chunks.append(torch.full_like(segment, group_id)) + parent_chunks.append(torch.full_like(segment, parent_id)) + position_chunks.append(torch.arange(start, end, device=device)) + for sequence_index in indices.tolist(): + positions_by_sequence[sequence_index].append(packed_positions) + cursor += len(segment) + return group_id + + def shared_end(indices: torch.Tensor, start: int) -> int: + end = int(lengths.index_select(0, indices).min().item()) + if start >= end: + return start + shared = ( + padded.index_select(0, indices)[:, start:end] + == padded[indices[0], start:end] + ).all(dim=0) + return ( + end + if bool(shared.all().item()) + else start + int(shared.logical_not().nonzero()[0]) + ) + + def branch_groups(indices: torch.Tensor, start: int) -> list[torch.Tensor]: + groups: dict[int, list[int]] = {} + order: list[int] 
= [] + symbols = padded.index_select(0, indices)[:, start].tolist() + for symbol, index in zip(symbols, indices.tolist(), strict=True): + if symbol not in groups: + groups[symbol] = [] + order.append(symbol) + groups[symbol].append(index) + return [ + torch.tensor(groups[symbol], dtype=torch.long, device=device) + for symbol in order + ] + + def walk( + indices: torch.Tensor, + start: int, + parent_group_id: int | None, + depth: int, + ) -> None: + active = indices[lengths.index_select(0, indices) > start] + if int(active.numel()) == 0: + return + if ( + max_depth == 0 + or int(active.numel()) == 1 + or (parent_group_id is not None and depth >= max_depth) + ): + for sequence_index in active: + emit( + sequence_index[None], + start, + int(lengths[sequence_index].item()), + parent_group_id, + ) + return + + end = shared_end(active, start) + if end > start: + group_id = emit(active, start, end, parent_group_id) + walk(active, end, group_id, depth + 1) + return + + for group in branch_groups(active, start): + walk(group, start, parent_group_id, depth) + + walk(torch.arange(len(tensors), device=device), 0, None, 0) + + return SharedPrefixPack( + tokens=torch.cat(token_chunks).unsqueeze(0), + group_ids=torch.cat(group_chunks).unsqueeze(0), + parent_ids=torch.cat(parent_chunks).unsqueeze(0), + position_ids=torch.cat(position_chunks).unsqueeze(0), + positions_by_sequence=tuple( + torch.cat(chunks) + if chunks + else torch.empty(0, dtype=torch.long, device=device) + for chunks in positions_by_sequence + ), + ) + + +def visualize_shared_prefix_pack(pack: SharedPrefixPack) -> str: + rows = ["pos token group parent source_pos"] + for position, (token, group, parent, source_pos) in enumerate( + zip( + pack.tokens.reshape(-1).detach().cpu().tolist(), + pack.group_ids.reshape(-1).detach().cpu().tolist(), + pack.parent_ids.reshape(-1).detach().cpu().tolist(), + pack.position_ids.reshape(-1).detach().cpu().tolist(), + strict=True, + ) + ): + rows.append(f"{position:>3} {token:>5} 
{group:>5} {parent:>6} {source_pos:>10}") + for index, positions in enumerate(pack.positions_by_sequence): + rows.append(f"seq {index}: {positions.detach().cpu().tolist()}") + return "\n".join(rows) + + +def _empty_pack( + sequence_count: int = 0, + *, + device: torch.device | None = None, +) -> SharedPrefixPack: + flat = torch.empty(0, dtype=torch.long, device=device) + row = flat.unsqueeze(0) + return SharedPrefixPack( + tokens=row, + group_ids=row, + parent_ids=row, + position_ids=row, + positions_by_sequence=tuple(flat for _ in range(sequence_count)), + ) + + +def _sequence_tensor(tensor: torch.Tensor) -> torch.Tensor: + if tensor.ndim != 1: + raise ValueError( + f"pack_shared_prefixes expects 1-D tensors, got {tuple(tensor.shape)}" + ) + return tensor.detach().to(dtype=torch.long).contiguous() diff --git a/src/art/megatron/shared_prefix_state.py b/src/art/megatron/shared_prefix_state.py index 7bbda4624..4221a3e0d 100644 --- a/src/art/megatron/shared_prefix_state.py +++ b/src/art/megatron/shared_prefix_state.py @@ -118,30 +118,101 @@ def _build_sparse_shared_prefix_block_mask( group_ids=group_ids_cpu, parent_ids=parent_ids_cpu, ) - row_spec = batch_spec.rows[0] seq_len = int(group_ids_cpu.shape[1]) - slices = _full_row_slices_with_padding( - row_slices=row_spec.slices, - valid_tokens=int(row_spec.valid_tokens), + row_masks = [] + token_indices = torch.arange(seq_len, dtype=torch.int64) + for row_spec in batch_spec.rows: + row_index = int(row_spec.row_index) + slices = _row_local_slices( + _full_row_slices_with_padding( + row_slices=row_spec.slices, + valid_tokens=int(row_spec.valid_tokens), + seq_len=seq_len, + ) + ) + if not slices: + row_masks.append( + _empty_block_mask(seq_len=seq_len, block_size=block_size, device=device) + ) + continue + row_masks.append( + build_block_mask( + FlexMaskSpec( + q_len=seq_len, + k_len=seq_len, + block_size=block_size, + slices=slices, + exact_mask=ExactMaskMetadata( + q_token_indices=token_indices, + 
def _stack_optional_block_tensors(
    masks: list[BlockMask],
    name: str,
) -> Tensor | None:
    """Concatenate attribute ``name`` of every mask along dim 0.

    The attribute is read from every mask first; if any of them is ``None``
    the result is ``None`` instead of a tensor.
    """
    parts = []
    for mask in masks:
        parts.append(getattr(mask, name))
    all_present = all(part is not None for part in parts)
    return torch.cat(parts, dim=0) if all_present else None
@dataclass(frozen=True, slots=True)
class SharedPrefixSegment:
    """One contiguous run of a single group id inside a packed row."""

    # Index of the packed row this run was parsed from.
    row_index: int
    # Ordinal of this run within the row, in scan order.
    run_index: int
    # Group id shared by every token of the run.
    group_id: int
    # Parent group id recorded for the run.
    parent_id: int
    # Token offsets of the run inside the row; `length` is `end - start`.
    start: int
    end: int
    # Index of the root family this run belongs to.
    family_index: int
    # Group id of the root segment of that family.
    root_group_id: int
    # Group ids of the ancestor segments; its length is the segment's depth.
    ancestors: tuple[int, ...]

    @property
    def depth(self) -> int:
        """Number of ancestors above this segment."""
        return len(self.ancestors)

    @property
    def length(self) -> int:
        """Number of tokens covered, i.e. ``end - start``."""
        return self.end - self.start
def parse_shared_prefix_row(
    *,
    group_ids: torch.Tensor,
    parent_ids: torch.Tensor,
    row_index: int = 0,
    ignore_padding_group_id: int = -1,
    require_contiguous_group_runs: bool = True,
) -> SharedPrefixRowTree:
    """Parse one packed row of group/parent ids into a prefix-tree description.

    The valid (non-padding) part of the row is split into runs of equal group
    id. Each run becomes a ``SharedPrefixSegment``:

    * a run is a root when its group id equals its parent id, or when it
      starts at offset 0 and its parent id is the padding id; each root opens
      a new family;
    * any other run inherits family and root from the first run of its parent
      group and appends the parent id to the parent's ancestor chain;
    * a later run of an already-seen group reuses the family, root and
      ancestors of that group's first run.

    Raises:
        RuntimeError: on shape/rank mismatch, on repeated group runs when
            ``require_contiguous_group_runs`` is set, when a run's parent has
            not been seen, or when the parent run ends after the child starts.
    """
    if group_ids.shape != parent_ids.shape:
        raise RuntimeError(
            "group_ids and parent_ids must share shape, got "
            f"{tuple(group_ids.shape)} vs {tuple(parent_ids.shape)}"
        )
    if group_ids.ndim != 1:
        raise RuntimeError(
            f"group_ids and parent_ids must be rank-1 row tensors, got {group_ids.ndim}"
        )

    valid_tokens = _valid_length(
        group_ids,
        parent_ids,
        ignore_padding_group_id=ignore_padding_group_id,
    )
    if valid_tokens == 0:
        return SharedPrefixRowTree(row_index=row_index, valid_tokens=0, segments=())

    runs = _scan_runs(group_ids[:valid_tokens], parent_ids[:valid_tokens])

    if require_contiguous_group_runs:
        runs_per_group: dict[int, int] = {}
        for _start, _end, seen_group, _parent in runs:
            runs_per_group[seen_group] = runs_per_group.get(seen_group, 0) + 1
        repeated_groups = {
            seen_group: count
            for seen_group, count in runs_per_group.items()
            if count > 1 and seen_group != ignore_padding_group_id
        }
        if repeated_groups:
            raise RuntimeError(
                "Shared-prefix metadata requires contiguous group runs per row, "
                f"found repeats in row {row_index}: {repeated_groups}"
            )

    # The first segment of each group carries its family, root and ancestors,
    # so later runs and children can read them straight off that segment.
    first_by_group: dict[int, SharedPrefixSegment] = {}
    parsed: list[SharedPrefixSegment] = []
    family_counter = 0

    for run_index, (start, end, group_id, parent_id) in enumerate(runs):
        earlier = first_by_group.get(group_id)
        if earlier is not None:
            family_index = earlier.family_index
            root_group_id = earlier.root_group_id
            ancestors: tuple[int, ...] = earlier.ancestors
        elif group_id == parent_id or (
            start == 0 and parent_id == ignore_padding_group_id
        ):
            family_index = family_counter
            family_counter += 1
            root_group_id = group_id
            ancestors = ()
        else:
            parent_segment = first_by_group.get(parent_id)
            if parent_segment is None:
                raise RuntimeError(
                    "Shared-prefix run points to a missing parent run: "
                    f"row={row_index}, group_id={group_id}, parent_id={parent_id}"
                )
            if int(parent_segment.end) > int(start):
                raise RuntimeError(
                    "Shared-prefix parent run must end before its child starts: "
                    f"row={row_index}, group_id={group_id}, parent_id={parent_id}"
                )
            family_index = parent_segment.family_index
            root_group_id = parent_segment.root_group_id
            ancestors = (*parent_segment.ancestors, parent_id)

        segment = SharedPrefixSegment(
            row_index=row_index,
            run_index=run_index,
            group_id=group_id,
            parent_id=parent_id,
            start=start,
            end=end,
            family_index=family_index,
            root_group_id=root_group_id,
            ancestors=ancestors,
        )
        if earlier is None:
            first_by_group[group_id] = segment
        parsed.append(segment)

    return SharedPrefixRowTree(
        row_index=row_index,
        valid_tokens=valid_tokens,
        segments=tuple(parsed),
    )
ignore_padding_group_id=ignore_padding_group_id, + ) + ), + default=0, + ) + + +def _valid_length( + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + *, + ignore_padding_group_id: int, +) -> int: + valid_mask = group_ids != ignore_padding_group_id + valid_count = int(valid_mask.sum().item()) + if valid_count == 0: + return 0 + if not bool(valid_mask[:valid_count].all().item()): + raise RuntimeError("Padding tokens must be a contiguous tail") + return _infer_terminal_padding_length( + group_ids[:valid_count], + parent_ids[:valid_count], + ) + + +def _infer_terminal_padding_length( + group_row: torch.Tensor, + parent_row: torch.Tensor, +) -> int: + if group_row.numel() == 0: + return 0 + runs = _scan_runs(group_row, parent_row) + if len(runs) < 2: + return int(group_row.numel()) + last_start, _last_end, last_group_id, last_parent_id = runs[-1] + if last_parent_id >= 0: + return int(group_row.numel()) + terminal_pair = (last_group_id, last_parent_id) + if any( + (group_id, parent_id) == terminal_pair + for _start, _end, group_id, parent_id in runs[:-1] + ): + return last_start + return int(group_row.numel()) + + +def _scan_runs( + group_row: torch.Tensor, + parent_row: torch.Tensor, +) -> list[tuple[int, int, int, int]]: + length = int(group_row.numel()) + if length == 0: + return [] + + group_changes = group_row[1:] != group_row[:-1] + parent_changes = parent_row[1:] != parent_row[:-1] + inconsistent_parent = torch.nonzero( + torch.logical_not(group_changes) & parent_changes, + as_tuple=False, + ).flatten() + if int(inconsistent_parent.numel()) > 0: + mismatch_index = int(inconsistent_parent[0].item()) + 1 + prior_boundaries = torch.nonzero( + group_changes[: mismatch_index - 1], + as_tuple=False, + ).flatten() + start = ( + 0 + if int(prior_boundaries.numel()) == 0 + else int(prior_boundaries[-1].item()) + 1 + ) + group_id = int(group_row[start].item()) + raise RuntimeError( + "Found one group run with inconsistent parent ids: " + f"group_id={group_id}, 
start={start}, end={mismatch_index}" + ) + + run_starts = torch.cat( + ( + torch.zeros(1, dtype=torch.int64, device=group_row.device), + torch.nonzero(group_changes, as_tuple=False).flatten() + 1, + ) + ) + run_ends = torch.cat( + ( + run_starts[1:], + torch.tensor([length], dtype=torch.int64, device=group_row.device), + ) + ) + starts = run_starts.to(device="cpu").tolist() + ends = run_ends.to(device="cpu").tolist() + group_ids = group_row.index_select(0, run_starts).to(device="cpu").tolist() + parent_ids = parent_row.index_select(0, run_starts).to(device="cpu").tolist() + return [ + (int(start), int(end), int(group_id), int(parent_id)) + for start, end, group_id, parent_id in zip( + starts, ends, group_ids, parent_ids, strict=True + ) + ] diff --git a/src/art/megatron/trainer_rank.py b/src/art/megatron/trainer_rank.py new file mode 100644 index 000000000..a2d8ce87d --- /dev/null +++ b/src/art/megatron/trainer_rank.py @@ -0,0 +1,2461 @@ +from __future__ import annotations + +from collections.abc import Callable, Iterable, Iterator, MutableMapping, Sequence +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import zip_longest +import os +from typing import TYPE_CHECKING, Generic, Literal, ParamSpec, TypeVar, cast, overload + +import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +import torch.distributed as dist + +from art.megatron.shared_prefix_packing import pack_shared_prefixes + +if TYPE_CHECKING: + from megatron.bridge.models.gpt_provider import GPTModelProvider + from megatron.core.models.gpt.gpt_model import GPTModel + from megatron.core.optimizer import MegatronOptimizer, OptimizerConfig + from megatron.core.packed_seq_params import PackedSeqParams + + from art.megatron.context_parallel.types import ( + ArtContextParallelState, + ParallelTopology, + ) + from art.megatron.lora import LoRASlotRef + from art.megatron.model_support import ModelSupportHandler + from 
@dataclass(frozen=True)
class AdamParams:
    """Immutable bundle of optimizer hyperparameters.

    Only ``learning_rate`` is required; every other field has a default.
    NOTE(review): the code that consumes these values is not visible in this
    chunk, so the exact optimizer semantics should be confirmed at the call
    site.
    """

    learning_rate: float
    beta1: float = 0.9
    beta2: float = 0.99
    weight_decay: float = 0.1
    # presumably the maximum gradient norm used for clipping — TODO confirm
    grad_clip_norm: float = 0.1
top_k: None = None, + logits: Literal[False] = False, + hidden_states: Literal[False] = False, + ) -> "ForwardInput[None, None, None, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: None = None, + logits: Literal[False] = False, + hidden_states: Literal[False] = False, + ) -> "ForwardInput[torch.Tensor, None, None, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: int, + logits: Literal[False] = False, + hidden_states: Literal[False] = False, + ) -> "ForwardInput[None, TopK, None, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: None = None, + logits: Literal[True], + hidden_states: Literal[False] = False, + ) -> "ForwardInput[None, None, torch.Tensor, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: None = None, + logits: Literal[False] = False, + hidden_states: Literal[True], + ) -> "ForwardInput[None, None, None, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: int, + logits: Literal[False] = False, + hidden_states: Literal[False] = False, + ) -> "ForwardInput[torch.Tensor, TopK, None, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: None = None, + logits: Literal[True], + hidden_states: Literal[False] = False, + ) -> "ForwardInput[torch.Tensor, None, torch.Tensor, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: None = None, + logits: Literal[False] = False, + hidden_states: Literal[True], + ) -> "ForwardInput[torch.Tensor, None, None, torch.Tensor]": ... 
+ + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: int, + logits: Literal[True], + hidden_states: Literal[False] = False, + ) -> "ForwardInput[None, TopK, torch.Tensor, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: int, + logits: Literal[False] = False, + hidden_states: Literal[True], + ) -> "ForwardInput[None, TopK, None, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: None = None, + logits: Literal[True], + hidden_states: Literal[True], + ) -> "ForwardInput[None, None, torch.Tensor, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: int, + logits: Literal[True], + hidden_states: Literal[False] = False, + ) -> "ForwardInput[torch.Tensor, TopK, torch.Tensor, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: int, + logits: Literal[False] = False, + hidden_states: Literal[True], + ) -> "ForwardInput[torch.Tensor, TopK, None, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: None = None, + logits: Literal[True], + hidden_states: Literal[True], + ) -> "ForwardInput[torch.Tensor, None, torch.Tensor, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: int, + logits: Literal[True], + hidden_states: Literal[True], + ) -> "ForwardInput[None, TopK, torch.Tensor, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: int, + logits: Literal[True], + hidden_states: Literal[True], + ) -> "ForwardInput[torch.Tensor, TopK, torch.Tensor, torch.Tensor]": ... 
@dataclass(frozen=True)
class _PushedSlot:
    """Context-manager handle for a slot ref pushed onto a trainer's slot stack.

    Entering returns the handle itself. Exiting checks that the trainer's
    stack still ends with this handle's ref and then pops it through
    ``trainer.pop_pushed_lora_or_checkpoint()``.
    """

    trainer: "TrainerRank"
    ref: "LoRASlotRef"

    def __enter__(self) -> "_PushedSlot":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: object,
    ) -> bool:
        """Pop this handle's ref; never suppresses an in-flight exception.

        Raises:
            RuntimeError: if the stack is empty or its top is not ``self.ref``.
        """
        stack = self.trainer._slot_stack
        if len(stack) == 0 or stack[-1] != self.ref:
            raise RuntimeError(
                "Pushed LoRA/checkpoint stack changed before context exit"
            )
        self.trainer.pop_pushed_lora_or_checkpoint()
        return False
+ tokens: torch.Tensor + group_ids: torch.Tensor + parent_ids: torch.Tensor + position_ids: torch.Tensor + positions_by_item: tuple[torch.Tensor, ...] + + +@dataclass(frozen=True) +class _PreparedPackedForward: + tokens: torch.Tensor + position_ids: torch.Tensor + attention_state: "SharedPrefixAttentionState | ArtContextParallelState" + packed_seq_params: "PackedSeqParams | None" + positions_by_item: tuple[torch.Tensor, ...] + source_positions_by_item: tuple[torch.Tensor, ...] + + +@dataclass(frozen=True) +class _HeadOutputs: + target_logprobs: list[torch.Tensor | None] + top_k: list[TopK | None] + logits: list[torch.Tensor | None] + + +@dataclass(frozen=True) +class _RowMatch: + source_offsets: torch.Tensor + row_offsets: torch.Tensor + + +class TrainerRank: + def __init__( + self, + runtime: TrainingRuntime, + *, + micro_batch_size: int = 1, + head_chunk_tokens: int = 512, + shared_prefix_max_depth: int = 1, + ) -> None: + if micro_batch_size < 1: + raise ValueError("micro_batch_size must be >= 1") + if head_chunk_tokens < 1: + raise ValueError("head_chunk_tokens must be >= 1") + if shared_prefix_max_depth < 0: + raise ValueError("shared_prefix_max_depth must be >= 0") + self.runtime: TrainingRuntime = runtime + self.micro_batch_size = micro_batch_size + self.head_chunk_tokens = head_chunk_tokens + self.shared_prefix_max_depth = shared_prefix_max_depth + self.device = next(runtime.model[0].parameters()).device + self._default_slot_ref: LoRASlotRef | None = None + self._slot_stack: list[LoRASlotRef] = [] + self._dynamic_optimizers: dict[str, torch.optim.Optimizer] = {} + self._checkpoint_slot_names: set[str] = set() + self.zero_grad() + + def zero_grad(self) -> None: + for chunk in self.runtime.model: + zero_grad_buffer = getattr(chunk, "zero_grad_buffer", None) + if callable(zero_grad_buffer): + zero_grad_buffer() + optimizer = cast("MegatronOptimizer | None", self.runtime.optimizer) + if optimizer is not None: + optimizer.zero_grad() + for name in 
self._checkpoint_slot_names: + for param in self._checkpoint_slot_params(name): + param.grad = None + + def _optimizer(self) -> "MegatronOptimizer": + optimizer = cast("MegatronOptimizer | None", self.runtime.optimizer) + if optimizer is None: + raise RuntimeError("TrainerRank requires a runtime with an optimizer") + return optimizer + + def _handler(self) -> "ModelSupportHandler": + return cast("ModelSupportHandler", self.runtime.model_support_handler) + + def _provider(self) -> "GPTModelProvider": + return cast("GPTModelProvider", self.runtime.provider) + + def set_checkpoint(self, name: str | None) -> None: + self._set_default_slot(self._slot_ref("checkpoint", name)) + + def set_lora(self, name: str | None) -> None: + self._set_default_slot(self._slot_ref("lora", name)) + + def push_checkpoint(self, name: str | None) -> _PushedSlot: + ref = self._slot_ref("checkpoint", name) + self._slot_stack.append(ref) + return _PushedSlot(self, ref) + + def push_lora(self, name: str | None) -> _PushedSlot: + ref = self._slot_ref("lora", name) + self._slot_stack.append(ref) + return _PushedSlot(self, ref) + + def pop_pushed_lora_or_checkpoint(self) -> None: + if not self._slot_stack: + raise RuntimeError("No pushed LoRA or checkpoint to pop") + self._slot_stack.pop() + + def load_checkpoint_slot( + self, + name: str, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float | None = None, + ) -> int: + loaded = self._load_slot( + "checkpoint", name, adapter_model, trainable=True, alpha=alpha + ) + self._validate_dynamic_slot_consistency("checkpoint", name, loaded) + self._checkpoint_slot_names.add(name) + return loaded + + def load_lora_slot( + self, + name: str, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float | None = None, + ) -> int: + loaded = self._load_slot( + "lora", name, adapter_model, trainable=False, alpha=alpha + ) + self._validate_dynamic_slot_consistency("lora", name, loaded) + return loaded + + def _load_slot( + self, + kind: 
Literal["checkpoint", "lora"], + name: str, + adapter_model: dict[str, torch.Tensor], + *, + trainable: bool, + alpha: float | None, + ) -> int: + from art.megatron.lora import LORA_ALPHA, load_lora_slot_into_model + + return load_lora_slot_into_model( + self.runtime.model, + self._slot_ref(kind, name), + adapter_model, + alpha=LORA_ALPHA if alpha is None else alpha, + requires_grad=trainable, + ) + + def _set_default_slot(self, ref: "LoRASlotRef") -> None: + if self._slot_stack: + raise RuntimeError("Cannot set a LoRA/checkpoint while a slot is pushed") + self._default_slot_ref = ref + + @staticmethod + def _slot_ref( + kind: Literal["checkpoint", "lora"], name: str | None + ) -> "LoRASlotRef": + from art.megatron.lora import LoRASlotRef + + return LoRASlotRef(kind=kind, name=name) + + def _validate_dynamic_slot_consistency( + self, + kind: Literal["checkpoint", "lora"], + name: str, + loaded_sites: int, + ) -> None: + if not (dist.is_available() and dist.is_initialized()): + return + + from art.megatron.lora import iter_lora_slot_parameters + + ref = self._slot_ref(kind, name) + params = list(iter_lora_slot_parameters(self.runtime.model, ref)) + local = { + "rank": dist.get_rank(), + "loaded_sites": int(loaded_sites), + "param_count": len(params), + "numel": sum(int(param.numel()) for param in params), + "signature": [ + ( + tuple(int(dim) for dim in param.shape), + str(param.dtype), + bool(getattr(param, "allreduce", True)), + str(getattr(param, "grad_sync_domain", "tp_default")), + str(getattr(param, "grad_sync_op", "none")), + ) + for param in params + ], + } + gathered: list[dict[str, object] | None] = [None] * dist.get_world_size() + dist.all_gather_object(gathered, local) + ranks = [rank for rank in gathered if rank is not None] + reference = ranks[0] + mismatched = [ + rank + for rank in ranks + if rank["loaded_sites"] != reference["loaded_sites"] + or rank["signature"] != reference["signature"] + ] + if not mismatched: + return + + first_mismatch = None + 
for left, right in zip_longest( + cast(list[object], reference["signature"]), + cast(list[object], mismatched[0]["signature"]), + fillvalue=None, + ): + if left != right: + first_mismatch = {"expected": left, "actual": right} + break + summary = [ + { + "rank": rank["rank"], + "loaded_sites": rank["loaded_sites"], + "param_count": rank["param_count"], + "numel": rank["numel"], + } + for rank in ranks + ] + raise RuntimeError( + f"Dynamic LoRA slot {kind}:{name} is not loaded consistently across " + "distributed ranks. This usually means a sharded/exported LoRA state " + "dict was passed directly to TrainerRank; gather or materialize the " + "full adapter state before loading a dynamic slot. " + f"Rank summary: {summary}. First mismatch: {first_mismatch}." + ) + + def _resolve_slot_ref(self, request: AnyForwardInput) -> "LoRASlotRef | None": + if request.checkpoint is not Unset: + return self._slot_ref("checkpoint", cast(str | None, request.checkpoint)) + if request.lora is not Unset: + return self._slot_ref("lora", cast(str | None, request.lora)) + if self._slot_stack: + return self._slot_stack[-1] + return self._default_slot_ref + + def _set_current_slot(self, ref: "LoRASlotRef | None") -> object: + from art.megatron.lora import set_lora_slot_context + + return set_lora_slot_context(ref) + + def _reset_current_slot(self, token: object) -> None: + from art.megatron.lora import reset_lora_slot_context + + reset_lora_slot_context(token) # type: ignore[arg-type] + + @contextmanager + def _use_slot(self, ref: "LoRASlotRef | None") -> Iterator[None]: + token = self._set_current_slot(ref) + try: + yield + finally: + self._reset_current_slot(token) + + def micro_batches( + self, + inputs: Iterable[ForwardInputsT], + ) -> Sequence[MicroBatch[ForwardInputsT]]: + items = list(inputs) + from megatron.core import parallel_state as ps + + dp_rank = int(ps.get_data_parallel_rank()) + dp_size = int(ps.get_data_parallel_world_size()) + global_micro_size = self.micro_batch_size * 
dp_size + batches: list[MicroBatch[ForwardInputsT]] = [] + for start in range(0, len(items), global_micro_size): + stop = min(start + global_micro_size, len(items)) + indices = list(range(start + dp_rank, stop, dp_size)) + batches.append(MicroBatch([items[i] for i in indices], indices)) + return batches + + @overload + def forward( + self, + inputs: Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]], + ) -> Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]: ... + + @overload + def forward( + self, + inputs: Iterable[ + Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ], + ) -> Sequence[ + Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ]: ... + + @overload + def forward( + self, + inputs: Iterable[ + Iterable[Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]] + ], + ) -> Sequence[ + Sequence[Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]] + ]: ... + + @overload + def forward( + self, + inputs: Iterable[ + Iterable[ + Iterable[ + Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ] + ] + ], + ) -> Sequence[ + Sequence[ + Sequence[Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]] + ] + ]: ... 
+ + def forward(self, inputs: ForwardInputs) -> ForwardOutputs: + materialized = _materialize(inputs) + outputs = iter(self._forward_flat(list(_flatten(materialized)))) + return _unflatten(materialized, outputs) + + def dp_reduce( + self, + tensor: torch.Tensor, + *, + op: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM, + ) -> None: + from megatron.core import parallel_state as ps + + dist.all_reduce( + tensor, + op=op, + group=ps.get_data_parallel_group(with_context_parallel=True), + ) + + def optim_step( + self, + *, + params: AdamParams, + scale_grads: float = 1.0, + checkpoints: Sequence[str] | None = None, + ) -> dict[str, float]: + selected_checkpoints = self._selected_dynamic_checkpoints(checkpoints) + if selected_checkpoints: + return self._dynamic_optim_step( + selected_checkpoints, + params=params, + scale_grads=scale_grads, + ) + + from art.megatron.training.finalize_grads import ( + finalize_model_grads_extended, + flush_param_grads_to_main_grads, + ) + from art.megatron.training.model_chunks import as_megatron_api_chunks + + optimizer = self._optimizer() + flush_param_grads_to_main_grads(self.runtime.model) + finalize_model_grads_extended( + as_megatron_api_chunks(self.runtime.model), + num_tokens=None, + ) + self._scale_main_grads(scale_grads) + self._configure_optimizer(params) + update_successful, grad_norm, num_zeros = optimizer.step() + optimizer.zero_grad() + self.zero_grad() + return { + "learning_rate": float(params.learning_rate), + "grad_norm": float(grad_norm), + "update_successful": float(bool(update_successful)), + "num_zeros_in_grad": float(num_zeros or 0), + } + + def _selected_dynamic_checkpoints( + self, + checkpoints: Sequence[str] | None, + ) -> tuple[str, ...]: + if checkpoints is not None: + unknown = set(checkpoints) - self._checkpoint_slot_names + if unknown: + raise ValueError(f"Unknown checkpoint slots: {sorted(unknown)}") + return tuple(dict.fromkeys(checkpoints)) + names = [] + for name in sorted(self._checkpoint_slot_names): + 
local_has_grad = any( + param.grad is not None for param in self._checkpoint_slot_params(name) + ) + has_grad = torch.tensor( + int(local_has_grad), + device=self.device, + dtype=torch.int32, + ) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(has_grad, op=dist.ReduceOp.MAX) + if bool(has_grad.item()): + names.append(name) + return tuple(names) + + def _dynamic_optim_step( + self, + checkpoint_names: Sequence[str], + *, + params: AdamParams, + scale_grads: float, + ) -> dict[str, float]: + all_params: list[torch.nn.Parameter] = [] + for name in checkpoint_names: + slot_params = self._checkpoint_slot_params(name) + self._ensure_dynamic_grads(slot_params) + self._reduce_dynamic_grads(slot_params) + if scale_grads != 1.0: + for param in slot_params: + if param.grad is not None: + param.grad.mul_(scale_grads) + all_params.extend(slot_params) + + grad_norm = torch.nn.utils.clip_grad_norm_( + all_params, + max_norm=params.grad_clip_norm, + ) + for name in checkpoint_names: + optimizer = self._dynamic_optimizer(name, params) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + return { + "learning_rate": float(params.learning_rate), + "grad_norm": float(grad_norm), + "update_successful": 1.0, + "num_zeros_in_grad": 0.0, + } + + def _dynamic_optimizer( + self, + name: str, + params: AdamParams, + ) -> torch.optim.Optimizer: + optimizer = self._dynamic_optimizers.get(name) + if optimizer is None: + optimizer = torch.optim.AdamW( + self._checkpoint_slot_params(name), + lr=params.learning_rate, + betas=(params.beta1, params.beta2), + weight_decay=params.weight_decay, + ) + self._dynamic_optimizers[name] = optimizer + return optimizer + for group in optimizer.param_groups: + group["lr"] = params.learning_rate + group["betas"] = (params.beta1, params.beta2) + group["weight_decay"] = params.weight_decay + return optimizer + + def _checkpoint_slot_params(self, name: str) -> list[torch.nn.Parameter]: + from art.megatron.lora import 
iter_lora_slot_parameters + + return list( + iter_lora_slot_parameters( + self.runtime.model, + self._slot_ref("checkpoint", name), + ) + ) + + @staticmethod + def _ensure_dynamic_grads(params: Sequence[torch.nn.Parameter]) -> None: + for param in params: + if param.grad is None: + param.grad = torch.zeros_like(param) + + def _reduce_dynamic_grads(self, params: Sequence[torch.nn.Parameter]) -> None: + from megatron.core import parallel_state as ps + + buckets: list[ + tuple[ + object, + dist.ReduceOp.RedOpType, + torch.dtype, + torch.device, + list[torch.Tensor], + ] + ] = [] + + def add_to_bucket( + *, + group: object, + op: dist.ReduceOp.RedOpType, + grad: torch.Tensor, + ) -> None: + for ( + bucket_group, + bucket_op, + bucket_dtype, + bucket_device, + bucket_grads, + ) in buckets: + if ( + bucket_group is group + and bucket_op == op + and bucket_dtype == grad.dtype + and bucket_device == grad.device + ): + bucket_grads.append(grad) + return + buckets.append((group, op, grad.dtype, grad.device, [grad])) + + for param in params: + grad = param.grad + if grad is None: + continue + if bool(getattr(param, "allreduce", True)): + group = ps.get_data_parallel_group(with_context_parallel=True) + else: + group = ps.get_expert_data_parallel_group() + if group is not None and group.size() > 1: + add_to_bucket(group=group, op=dist.ReduceOp.SUM, grad=grad) + + op = getattr(param, "grad_sync_op", "none") + if op == "none": + continue + domain = getattr(param, "grad_sync_domain", "tp_default") + if domain == "expert_tp": + tp_group = ps.get_expert_tensor_parallel_group(check_initialized=False) + else: + tp_group = ps.get_tensor_model_parallel_group(check_initialized=False) + if tp_group is None or tp_group.size() <= 1: + continue + reduce_op = dist.ReduceOp.AVG if op == "avg" else dist.ReduceOp.SUM + add_to_bucket(group=tp_group, op=reduce_op, grad=grad) + + for group, op, _dtype, _device, grads in buckets: + self._coalesced_all_reduce(grads, group=group, op=op) + + 
@staticmethod + def _coalesced_all_reduce( + grads: Sequence[torch.Tensor], + *, + group: object, + op: dist.ReduceOp.RedOpType, + ) -> None: + if not grads: + return + coalesced = _flatten_dense_tensors(grads) + reduced = ( + coalesced.float() + if torch.is_floating_point(coalesced) and coalesced.dtype != torch.float32 + else coalesced + ) + dist.all_reduce(reduced, op=op, group=group) + if reduced is not coalesced: + reduced = reduced.to(dtype=coalesced.dtype) + for grad, synced in zip(grads, _unflatten_dense_tensors(reduced, grads)): + grad.copy_(synced) + + def _forward_flat( + self, requests: Sequence[AnyForwardInput] + ) -> list[AnyForwardOutput]: + outputs = [ + ForwardOutput( + target_logprobs=None, + top_k=None, + logits=None, + hidden_states=None, + ) + for _ in requests + ] + active_indices = [ + index + for index, request in enumerate(requests) + if request.target_tokens is not None + or request.logits + or request.top_k is not None + or request.hidden_states + ] + if not active_indices: + return outputs + + groups: dict[LoRASlotRef | None, list[int]] = {} + for index in active_indices: + groups.setdefault(self._resolve_slot_ref(requests[index]), []).append(index) + + for slot_ref, group_indices in groups.items(): + items = [self._forward_item(requests[index]) for index in group_indices] + packed = _pack_forward_items(items, max_depth=self.shared_prefix_max_depth) + with self._use_slot(slot_ref): + prepared = self._prepare_packed_forward(packed) + item_outputs = self._forward_packed(items, prepared) + for index, output in zip(group_indices, item_outputs, strict=True): + outputs[index] = output + return outputs + + def _forward_item(self, request: AnyForwardInput) -> _ForwardItem: + _validate_top_k(request.top_k, _language_model(self.runtime.model[0])) + input_ids = _as_1d_long(request.input_tokens, name="input_tokens") + labels = ( + _as_target_tokens(request.target_tokens, request.input_tokens, input_ids) + if request.target_tokens is not None + else 
None + ) + return _ForwardItem(request=request, input_ids=input_ids, labels=labels) + + def _forward_packed( + self, + items: Sequence[_ForwardItem], + prepared: _PreparedPackedForward, + ) -> list[AnyForwardOutput]: + if _is_native_target_only(items): + labels = self._consistent_packed_labels(items, prepared) + if labels is not None: + return self._forward_native_target_logprobs(items, prepared, labels) + + hidden_by_row = self._gather_sequence_parallel_hidden( + self._decoder_hidden(prepared) + ) + head_outputs = self._project_head(items, prepared, hidden_by_row) + outputs: list[AnyForwardOutput] = [] + for index, (item, positions) in enumerate( + zip(items, prepared.positions_by_item, strict=True) + ): + hidden_states = ( + _select_positions(hidden_by_row, positions) + if item.request.hidden_states + else None + ) + outputs.append( + ForwardOutput( + target_logprobs=head_outputs.target_logprobs[index], + top_k=head_outputs.top_k[index], + logits=head_outputs.logits[index], + hidden_states=hidden_states, + ) + ) + return outputs + + def _forward_native_target_logprobs( + self, + items: Sequence[_ForwardItem], + prepared: _PreparedPackedForward, + labels: torch.Tensor, + ) -> list[AnyForwardOutput]: + from art.megatron.train import _placeholder_attention_mask + + per_token_loss = self.runtime.model[0]( + input_ids=prepared.tokens, + position_ids=prepared.position_ids, + attention_mask=_placeholder_attention_mask(self.device), + labels=labels, + packed_seq_params=prepared.packed_seq_params, + **self._handler().get_forward_kwargs( + self.runtime.model[0], + attention_bias=prepared.attention_state, + ), + ) + flat_logprobs = -per_token_loss.reshape(-1) + outputs: list[AnyForwardOutput] = [] + for item, positions, source_positions in zip( + items, + prepared.positions_by_item, + prepared.source_positions_by_item, + strict=True, + ): + if item.labels is None: + raise RuntimeError("native target path requires labels") + item_labels = 
item.labels.to(device=self.device).index_select( + 0, + source_positions.to(device=self.device), + ) + target_logprobs = _select_positions(flat_logprobs, positions).masked_fill( + item_labels == -100, + 0.0, + ) + outputs.append( + ForwardOutput( + target_logprobs=target_logprobs, + top_k=None, + logits=None, + hidden_states=None, + ) + ) + return outputs + + def _consistent_packed_labels( + self, + items: Sequence[_ForwardItem], + prepared: _PreparedPackedForward, + ) -> torch.Tensor | None: + labels = torch.full_like(prepared.tokens, -100) + flat_labels = labels.reshape(-1) + has_label = torch.zeros_like(flat_labels, dtype=torch.bool) + for item, positions, source_positions in zip( + items, + prepared.positions_by_item, + prepared.source_positions_by_item, + strict=True, + ): + if item.labels is None: + continue + item_positions = positions.to(device=labels.device) + item_labels = item.labels.to(device=labels.device).index_select( + 0, + source_positions.to(device=labels.device), + ) + keep = item_labels != -100 + if not bool(keep.any().item()): + continue + kept_positions = item_positions[keep] + kept_labels = item_labels[keep] + existing = flat_labels.index_select(0, kept_positions) + seen = has_label.index_select(0, kept_positions) + if bool(((existing != kept_labels) & seen).any().item()): + return None + flat_labels.index_copy_(0, kept_positions, kept_labels) + has_label.index_fill_(0, kept_positions, True) + return labels + + def _decoder_hidden( + self, + prepared: _PreparedPackedForward, + ) -> torch.Tensor: + from art.megatron.train import _placeholder_attention_mask + + handler = self._handler() + model = _language_model(self.runtime.model[0]) + attention_mask = _placeholder_attention_mask(self.device) + forward_kwargs = handler.get_forward_kwargs( + self.runtime.model[0], + attention_bias=prepared.attention_state, + ) + extra_block_kwargs = cast( + dict[str, object] | None, + forward_kwargs.pop("extra_block_kwargs", None), + ) + preprocessed = 
model._preprocess( + input_ids=prepared.tokens, + position_ids=prepared.position_ids, + packed_seq_params=cast("PackedSeqParams", prepared.packed_seq_params), + ) + ( + decoder_input, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + padding_mask, + ) = preprocessed[:6] + rotary_pos_cos_sin = preprocessed[6] if len(preprocessed) == 7 else None + return cast( + torch.Tensor, + model.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + packed_seq_params=prepared.packed_seq_params, + sequence_len_offset=sequence_len_offset, + padding_mask=padding_mask, + **(extra_block_kwargs or {}), + ), + ) + + def _project_head( + self, + items: Sequence[_ForwardItem], + prepared: _PreparedPackedForward, + hidden_by_row: torch.Tensor, + ) -> "_HeadOutputs": + model = _language_model(self.runtime.model[0]) + output_weight = ( + model.shared_embedding_or_output_weight() + if bool(model.share_embeddings_and_output_weights) + else None + ) + device = hidden_by_row.device + target_logprobs = [None for _ in items] + logits: list[torch.Tensor | None] = [None for _ in items] + top_k: list[TopK | None] = [None for _ in items] + label_rows: list[torch.Tensor | None] = [None for _ in items] + full_rows: list[torch.Tensor] = [] + local_rows: list[torch.Tensor] = [] + + for index, (item, positions_cpu) in enumerate( + zip(items, prepared.positions_by_item, strict=True) + ): + positions = positions_cpu.to(device=device) + if item.request.logits: + full_rows.append(positions) + elif item.request.top_k is not None: + local_rows.append(positions) + if item.labels is not None: + source_positions = prepared.source_positions_by_item[index].to(device) + labels = item.labels.to(device=device).index_select(0, source_positions) + label_rows[index] = labels + target_logprobs[index] = torch.zeros( + tuple(labels.shape), 
+ device=device, + dtype=torch.float32, + ) + if item.request.top_k is None and not item.request.logits: + valid_offsets = _valid_target_offsets(labels) + if int(valid_offsets.numel()): + local_rows.append(positions.index_select(0, valid_offsets)) + if item.request.logits: + logits[index] = _empty_logits_like_positions( + positions, + model, + hidden_by_row, + ) + + full_row_tensor = ( + torch.cat(full_rows).unique(sorted=True) + if full_rows + else torch.empty(0, dtype=torch.long, device=device) + ) + local_row_tensor = ( + torch.cat(local_rows).unique(sorted=True) + if local_rows + else torch.empty(0, dtype=torch.long, device=device) + ) + if int(full_row_tensor.numel()) and int(local_row_tensor.numel()): + local_row_tensor = local_row_tensor[ + ~torch.isin(local_row_tensor, full_row_tensor) + ] + + if int(full_row_tensor.numel()): + self._project_full_logits( + items, + prepared, + hidden_by_row, + full_row_tensor, + output_weight=output_weight, + target_logprobs=target_logprobs, + top_k=top_k, + logits=logits, + label_rows=label_rows, + ) + + if int(local_row_tensor.numel()): + local_row_matches = _row_matches_by_item( + prepared.positions_by_item, + local_row_tensor, + device=device, + ) + self._project_vocab_parallel( + items, + hidden_by_row, + local_row_tensor, + row_matches=local_row_matches, + item_lengths=tuple( + int(positions.numel()) for positions in prepared.positions_by_item + ), + output_weight=output_weight, + target_logprobs=target_logprobs, + top_k=top_k, + label_rows=label_rows, + ) + + return _HeadOutputs(target_logprobs, top_k, logits) + + def _project_full_logits( + self, + items: Sequence[_ForwardItem], + prepared: _PreparedPackedForward, + hidden_by_row: torch.Tensor, + rows: torch.Tensor, + *, + output_weight: torch.Tensor | None, + target_logprobs: list[torch.Tensor | None], + top_k: list[TopK | None], + logits: list[torch.Tensor | None], + label_rows: list[torch.Tensor | None], + ) -> None: + model = 
_language_model(self.runtime.model[0]) + for start in range(0, int(rows.numel()), self.head_chunk_tokens): + chunk_rows = rows[start : start + self.head_chunk_tokens] + chunk_logits = self._logits_from_hidden_rows( + model, + _select_positions(hidden_by_row, chunk_rows), + output_weight=output_weight, + ) + log_z = None + if any( + item.labels is not None or item.request.top_k is not None + for item in items + ): + log_z = torch.logsumexp(chunk_logits.float(), dim=-1) + + for index, item in enumerate(items): + positions = prepared.positions_by_item[index].to(device=rows.device) + offsets, chunk_offsets = _matching_offsets(positions, chunk_rows) + if int(offsets.numel()) == 0: + continue + selected_logits = chunk_logits.index_select(0, chunk_offsets) + item_logits = logits[index] + if item_logits is not None: + item_logits[offsets] = selected_logits + labels = label_rows[index] + item_logprobs = target_logprobs[index] + if item_logprobs is not None and labels is not None: + if log_z is None: + raise RuntimeError("target logprobs require logsumexp") + item_logprobs[offsets] = _target_logprobs_from_full_logits( + selected_logits, + labels.index_select(0, offsets), + log_z.index_select(0, chunk_offsets), + ) + k = item.request.top_k + if k is not None: + if log_z is None: + raise RuntimeError("top_k requires logsumexp") + top_k[index] = _merge_topk( + top_k[index], + offsets, + _topk_from_full_logits( + selected_logits, + k=k, + log_z=log_z.index_select(0, chunk_offsets), + ), + length=int(positions.numel()), + ) + + def _project_vocab_parallel( + self, + items: Sequence[_ForwardItem], + hidden_by_row: torch.Tensor, + rows: torch.Tensor, + *, + row_matches: Sequence[_RowMatch], + item_lengths: Sequence[int], + output_weight: torch.Tensor | None, + target_logprobs: list[torch.Tensor | None], + top_k: list[TopK | None], + label_rows: list[torch.Tensor | None], + ) -> None: + model = _language_model(self.runtime.model[0]) + use_fused_target_ce = 
_can_use_fused_target_ce(items, label_rows) + fused_target_labels = ( + _consistent_row_labels( + label_rows, + row_matches, + row_count=int(rows.numel()), + device=rows.device, + ) + if use_fused_target_ce + else None + ) + if fused_target_labels is not None: + row_target_logprobs = torch.empty( + int(rows.numel()), + device=rows.device, + dtype=torch.float32, + ) + for start in range(0, int(rows.numel()), self.head_chunk_tokens): + chunk_rows = rows[start : start + self.head_chunk_tokens] + local_logits = self._local_logits_from_hidden_rows( + model, + _select_positions(hidden_by_row, chunk_rows), + output_weight=output_weight, + ) + row_target_logprobs[ + start : start + int(chunk_rows.numel()) + ] = -model.compute_language_model_loss( + fused_target_labels[ + start : start + int(chunk_rows.numel()) + ].unsqueeze(0), + local_logits.unsqueeze(1), + ).reshape(-1) + _scatter_row_target_logprobs( + row_target_logprobs, + row_matches, + label_rows, + target_logprobs, + ) + return + + reference_target_labels = ( + _reference_row_labels( + label_rows, + row_matches, + row_count=int(rows.numel()), + device=rows.device, + ) + if _can_use_reference_target_ce(items, label_rows) + else None + ) + if reference_target_labels is not None: + for start in range(0, int(rows.numel()), self.head_chunk_tokens): + chunk_rows = rows[start : start + self.head_chunk_tokens] + local_logits = self._local_logits_from_hidden_rows( + model, + _select_positions(hidden_by_row, chunk_rows), + output_weight=output_weight, + ) + chunk_reference_labels = reference_target_labels[ + start : start + int(chunk_rows.numel()) + ] + reference_loss = model.compute_language_model_loss( + chunk_reference_labels.unsqueeze(0), + local_logits.unsqueeze(1), + ).reshape(-1) + reference_logits = _vocab_parallel_target_logits( + local_logits, + chunk_reference_labels, + ) + log_z = reference_logits + reference_loss + for index, item_logprobs in enumerate(target_logprobs): + labels = label_rows[index] + if 
item_logprobs is None or labels is None: + continue + offsets, chunk_offsets = _match_chunk_offsets( + row_matches[index], + start=start, + end=start + int(chunk_rows.numel()), + ) + if int(offsets.numel()) == 0: + continue + item_logprobs[offsets] = _vocab_parallel_target_logprobs( + local_logits, + labels.index_select(0, offsets), + log_z.index_select(0, chunk_offsets), + row_offsets=chunk_offsets, + ) + return + + max_top_k = max( + (int(item.request.top_k or 0) for item in items if not item.request.logits), + default=0, + ) + for start in range(0, int(rows.numel()), self.head_chunk_tokens): + chunk_rows = rows[start : start + self.head_chunk_tokens] + local_logits = self._local_logits_from_hidden_rows( + model, + _select_positions(hidden_by_row, chunk_rows), + output_weight=output_weight, + ) + topk_stats = _try_triton_local_topk_stats(local_logits, k=max_top_k) + logsumexp_stats = ( + _try_triton_local_logsumexp_stats(local_logits) + if topk_stats is None + else None + ) + if topk_stats is not None: + local_max, local_sum, _, _ = topk_stats + local_max = local_max.detach() + global_max = _all_reduce_tensor_parallel_max(local_max) + global_sum = _all_reduce_tensor_parallel_sum( + local_sum * torch.exp(local_max - global_max) + ) + log_z = global_max + torch.log(global_sum) + elif logsumexp_stats is not None: + local_max, local_sum = logsumexp_stats + local_max = local_max.detach() + global_max = _all_reduce_tensor_parallel_max(local_max) + global_sum = _all_reduce_tensor_parallel_sum( + local_sum * torch.exp(local_max - global_max) + ) + log_z = global_max + torch.log(global_sum) + else: + log_z = _vocab_parallel_log_z(local_logits) + + logits_topk: tuple[torch.Tensor, torch.Tensor] | None = None + if logsumexp_stats is not None and max_top_k > 0: + local_k = min(max_top_k, int(local_logits.shape[1])) + local_values, local_tokens = torch.topk(local_logits, k=local_k, dim=-1) + logits_topk = (local_values.float(), local_tokens) + + for index, item in 
enumerate(items): + if item.request.logits: + continue + offsets, chunk_offsets = _match_chunk_offsets( + row_matches[index], + start=start, + end=start + int(chunk_rows.numel()), + ) + if int(offsets.numel()) == 0: + continue + selected_log_z = log_z.index_select(0, chunk_offsets) + labels = label_rows[index] + item_logprobs = target_logprobs[index] + if item_logprobs is not None and labels is not None: + item_logprobs[offsets] = _vocab_parallel_target_logprobs( + local_logits, + labels.index_select(0, offsets), + selected_log_z, + row_offsets=chunk_offsets, + ) + k = item.request.top_k + if k is not None: + if topk_stats is not None: + _, _, local_values, local_tokens = topk_stats + top_k[index] = _merge_topk( + top_k[index], + offsets, + _vocab_parallel_topk_from_local( + local_values.index_select(0, chunk_offsets), + local_tokens.index_select(0, chunk_offsets), + k=k, + log_z=selected_log_z, + vocab_start=_vocab_range(local_logits)[0], + ), + length=item_lengths[index], + ) + continue + if logits_topk is not None: + local_values, local_tokens = logits_topk + top_k[index] = _merge_topk( + top_k[index], + offsets, + _vocab_parallel_topk_from_local( + local_values.index_select(0, chunk_offsets), + local_tokens.index_select(0, chunk_offsets), + k=k, + log_z=selected_log_z, + vocab_start=_vocab_range(local_logits)[0], + ), + length=item_lengths[index], + ) + continue + selected_logits = local_logits.index_select(0, chunk_offsets) + top_k[index] = _merge_topk( + top_k[index], + offsets, + _vocab_parallel_topk( + selected_logits, + k=k, + log_z=selected_log_z, + ), + length=item_lengths[index], + ) + + def _logits_from_hidden_rows( + self, + model: "GPTModel", + hidden: torch.Tensor, + *, + output_weight: torch.Tensor | None, + ) -> torch.Tensor: + local_logits = self._local_logits_from_hidden_rows( + model, + hidden, + output_weight=output_weight, + ) + return _batch_seq_logits( + self._gather_tensor_parallel_logits(local_logits.unsqueeze(1)), + 
def _local_logits_from_hidden_rows(
    self,
    model: "GPTModel",
    hidden: torch.Tensor,
    *,
    output_weight: torch.Tensor | None,
) -> torch.Tensor:
    """Project selected hidden rows through the output layer without gathering.

    ``hidden`` is given a singleton second axis before the call and the
    result is normalised by ``_batch_seq_logits`` and squeezed back to two
    dimensions. The layer's ``sequence_parallel`` flag, when set, is switched
    off for the duration of the call and always restored afterwards.
    """
    head = model.output_layer
    had_sequence_parallel = bool(getattr(head, "sequence_parallel", False))
    if had_sequence_parallel:
        head.sequence_parallel = False
    try:
        raw_logits, _ = head(
            hidden.unsqueeze(1),
            weight=output_weight,
            runtime_gather_output=None,
        )
    finally:
        # Restore the flag even when the projection raises.
        if had_sequence_parallel:
            head.sequence_parallel = True
    scaled_logits = model._scale_logits(raw_logits)
    row_count = int(hidden.shape[0])
    return _batch_seq_logits(scaled_logits, seq_len=row_count).squeeze(0)
def _prepare_context_parallel_forward(
    self,
    batch: _PackedForwardBatch,
    *,
    topology: "ParallelTopology",
) -> _PreparedPackedForward:
    """Prepare a packed batch for a forward pass when context parallelism is on.

    The packed tensors are wrapped in a ``PackedTensors`` mapping whose
    training-only fields are filled with placeholders, handed to
    ``prepare_cp_micro``, and each item's packed positions are then paired
    with the positions dispatched to this rank.

    Raises:
        RuntimeError: if ``prepare_cp_micro`` returns no rank plan.
    """
    from megatron.core import parallel_state as ps

    from art.megatron.context_parallel.runtime import (
        _dispatch_tensor,
        prepare_cp_micro,
    )
    from art.megatron.training.microbatches import (
        _context_parallel_config_for_provider,
    )
    from art.preprocessing.pack import PackedTensors

    # Forward-only call: every token is marked as assistant text and the
    # training fields (logprobs, advantages, weights) are placeholders.
    assistant_mask = torch.ones_like(batch.tokens, dtype=torch.bool)
    sparse_micro: PackedTensors = {
        "tokens": batch.tokens,
        "group_ids": batch.group_ids,
        "parent_ids": batch.parent_ids,
        "input_pos": batch.position_ids,
        "assistant_mask": assistant_mask,
        "logprobs": torch.full_like(
            batch.tokens, float("nan"), dtype=torch.float32
        ),
        "advantages": torch.zeros_like(batch.tokens, dtype=torch.float32),
        "weights": assistant_mask.to(dtype=torch.float32),
        "pixel_values": [None],
        "image_grid_thw": [None],
        "moe_routing_replay": None,
    }
    handler = self._handler()
    prepared = prepare_cp_micro(
        micro=sparse_micro,
        topology=topology,
        config=_context_parallel_config_for_provider(self._provider(), self.device),
        cp_group=ps.get_context_parallel_group(check_initialized=False),
        cp_rank=ps.get_context_parallel_rank(),
        build_gdn_execution_spec=handler.build_gdn_execution_spec,
        target_device=self.device,
    )
    if prepared.rank_plan is None:
        raise RuntimeError("CP forward preparation did not return a rank plan")
    # Dispatch an arange over the packed sequence through the same rank plan
    # so each local slot carries the global packed position it came from;
    # padding slots are filled with -1.
    local_positions = _dispatch_tensor(
        torch.arange(
            int(batch.tokens.shape[1]),
            dtype=torch.long,
        ).unsqueeze(0),
        rank_plan=prepared.rank_plan,
        pad_value=-1,
        pad_multiple=prepared.pad_multiple,
    )
    # One (local positions, source offsets) pair per item.
    local_position_pairs = tuple(
        _local_position_pairs(local_positions, positions)
        for positions in batch.positions_by_item
    )
    return _PreparedPackedForward(
        tokens=prepared.tensors.tokens,
        position_ids=prepared.tensors.input_pos,
        attention_state=cast("ArtContextParallelState", prepared.attention_state),
        packed_seq_params=prepared.packed_seq_params,
        positions_by_item=tuple(pair[0] for pair in local_position_pairs),
        source_positions_by_item=tuple(pair[1] for pair in local_position_pairs),
    )
None: + param.grad.mul_(scale) + + +def _as_1d_long(tensor: torch.Tensor, *, name: str) -> torch.Tensor: + tensor = tensor.reshape(-1) + if int(tensor.numel()) == 0: + raise ValueError(f"{name} must not be empty") + return tensor.to(dtype=torch.long) + + +def _as_target_tokens( + tensor: torch.Tensor, + input_tokens: torch.Tensor, + input_ids: torch.Tensor, +) -> torch.Tensor: + labels = tensor.to(dtype=torch.long) + if int(labels.numel()) == 0: + raise ValueError("target_tokens must not be empty") + if tuple(labels.shape) == tuple(input_tokens.shape): + return labels.reshape(-1) + + input_shape = tuple(input_tokens.shape) + if ( + labels.ndim > input_tokens.ndim + and tuple(labels.shape[: input_tokens.ndim]) == input_shape + ): + return labels.reshape( + int(input_ids.numel()), *labels.shape[input_tokens.ndim :] + ) + if labels.ndim >= 1 and int(labels.shape[0]) == int(input_ids.numel()): + return labels + raise ValueError( + "target_tokens must match input_tokens or add trailing target dimensions: " + f"input_tokens={tuple(input_tokens.shape)} target_tokens={tuple(labels.shape)}" + ) + + +def _validate_top_k(top_k: int | None, model: "GPTModel") -> None: + if top_k is None: + return + if top_k < 1: + raise ValueError("top_k must be >= 1") + vocab_size = _padded_vocab_size(model) + if top_k > vocab_size: + raise ValueError(f"top_k={top_k} exceeds vocabulary size {vocab_size}") + + +def _is_native_target_only(items: Sequence[_ForwardItem]) -> bool: + return all( + item.labels is not None + and item.labels.ndim == 1 + and item.request.top_k is None + and not item.request.logits + and not item.request.hidden_states + for item in items + ) + + +def _pack_forward_items( + items: Sequence[_ForwardItem], + *, + max_depth: int, +) -> _PackedForwardBatch: + input_tensors = tuple(item.input_ids for item in items) + pack = pack_shared_prefixes(input_tensors, max_depth=max_depth) + + return _PackedForwardBatch( + tokens=pack.tokens, + group_ids=pack.group_ids, + 
parent_ids=pack.parent_ids, + position_ids=pack.position_ids, + positions_by_item=pack.positions_by_sequence, + ) + + +def _pad_packed_batch( + batch: _PackedForwardBatch, + *, + multiple: int, +) -> _PackedForwardBatch: + if multiple <= 1: + return batch + seq_len = int(batch.tokens.shape[1]) + pad = -seq_len % multiple + if pad == 0: + return batch + + device = batch.tokens.device + next_group = ( + int(batch.group_ids.max().item()) + 1 if int(batch.group_ids.numel()) else 1 + ) + pad_group_ids = torch.arange( + next_group, + next_group + pad, + dtype=batch.group_ids.dtype, + device=device, + ).unsqueeze(0) + return _PackedForwardBatch( + tokens=torch.cat( + ( + batch.tokens, + torch.zeros((1, pad), dtype=batch.tokens.dtype, device=device), + ), + dim=1, + ), + group_ids=torch.cat((batch.group_ids, pad_group_ids), dim=1), + parent_ids=torch.cat((batch.parent_ids, pad_group_ids), dim=1), + position_ids=torch.cat( + ( + batch.position_ids, + torch.zeros((1, pad), dtype=batch.position_ids.dtype, device=device), + ), + dim=1, + ), + positions_by_item=batch.positions_by_item, + ) + + +def _language_model(model: torch.nn.Module) -> "GPTModel": + module: object = model + while hasattr(module, "module"): + module = getattr(module, "module") + if hasattr(module, "_preprocess") and hasattr(module, "decoder"): + return cast("GPTModel", module) + language_model = getattr(module, "language_model", None) + if language_model is not None: + return cast("GPTModel", language_model) + raise RuntimeError("expected a Megatron GPT model") + + +def _empty_logits_like_positions( + positions: torch.Tensor, + model: "GPTModel", + like: torch.Tensor, +) -> torch.Tensor: + return torch.empty( + (int(positions.numel()), _padded_vocab_size(model)), + device=like.device, + dtype=like.dtype, + ) + + +def _padded_vocab_size(model: "GPTModel") -> int: + vocab_size = getattr(getattr(model, "config", None), "padded_vocab_size", None) + if vocab_size is None: + vocab_size = getattr(model, 
def _vocab_parallel_target_logits(
    local_logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    row_offsets: torch.Tensor | None = None,
) -> torch.Tensor:
    """Return each label's logit by combining per-shard lookups.

    Each rank picks the logits of the labels that fall inside its own
    vocabulary shard; the per-rank results are then summed over the
    tensor-parallel group. When ``row_offsets`` is given, it selects the row
    of ``local_logits`` to read for each label row; otherwise labels and
    logits are matched row by row.
    """
    vocab_start = _vocab_range(local_logits)[0]
    if row_offsets is not None:
        owned = _call_compiled(
            _owned_target_logits_for_rows,
            local_logits,
            labels,
            vocab_start,
            row_offsets,
        )
    else:
        owned = _call_compiled(
            _owned_target_logits,
            local_logits,
            labels,
            vocab_start,
        )
    return _all_reduce_tensor_parallel_sum(owned)
local_labels.clamp(0, int(local_logits.shape[1]) - 1), + ).float() + return selected.masked_fill(~owns_label, 0.0).reshape(labels.shape) + + +def _owned_target_logits_for_rows( + local_logits: torch.Tensor, + labels: torch.Tensor, + vocab_start: int, + row_offsets: torch.Tensor, +) -> torch.Tensor: + flat_labels = labels.reshape(int(labels.shape[0]), -1) + local_labels = flat_labels - vocab_start + owns_label = ( + (flat_labels != -100) + & (local_labels >= 0) + & (local_labels < int(local_logits.shape[1])) + ) + rows = row_offsets.reshape(int(row_offsets.shape[0]), 1).expand_as(flat_labels) + selected = local_logits[ + rows, + local_labels.clamp(0, int(local_logits.shape[1]) - 1), + ].float() + return selected.masked_fill(~owns_label, 0.0).reshape(labels.shape) + + +def _finish_target_logprobs( + target_logits: torch.Tensor, + labels: torch.Tensor, + log_z: torch.Tensor, +) -> torch.Tensor: + log_z = log_z.reshape(int(log_z.shape[0]), *((1,) * (int(labels.ndim) - 1))) + return (target_logits.float() - log_z).masked_fill(labels == -100, 0.0) + + +def _valid_target_offsets(labels: torch.Tensor) -> torch.Tensor: + if int(labels.shape[0]) == 0: + return torch.empty(0, dtype=torch.long, device=labels.device) + valid = labels != -100 + if labels.ndim > 1: + valid = valid.reshape(int(labels.shape[0]), -1).any(dim=1) + return torch.nonzero(valid, as_tuple=False).reshape(-1) + + +def _can_use_fused_target_ce( + items: Sequence[_ForwardItem], + label_rows: Sequence[torch.Tensor | None], +) -> bool: + return all(item.request.top_k is None for item in items) and all( + labels is None or labels.ndim == 1 for labels in label_rows + ) + + +def _can_use_reference_target_ce( + items: Sequence[_ForwardItem], + label_rows: Sequence[torch.Tensor | None], +) -> bool: + return ( + os.environ.get("ART_TRAINER_RANK_REFERENCE_TARGET_CE", "0").lower() + not in {"0", "false"} + and all( + item.request.top_k is None and not item.request.logits for item in items + ) + and any(labels is not 
None and labels.ndim > 1 for labels in label_rows) + ) + + +def _reference_row_labels( + label_rows: Sequence[torch.Tensor | None], + row_matches: Sequence[_RowMatch], + *, + row_count: int, + device: torch.device, +) -> torch.Tensor | None: + references = torch.full((row_count,), -100, dtype=torch.long, device=device) + for labels, match in zip(label_rows, row_matches, strict=True): + if labels is None or int(match.source_offsets.numel()) == 0: + continue + selected = labels.index_select(0, match.source_offsets).reshape( + int(match.source_offsets.numel()), + -1, + ) + valid = selected != -100 + has_label = valid.any(dim=1) + if not bool(has_label.any()): + continue + candidates = selected.gather( + 1, + valid.to(torch.int64).argmax(dim=1, keepdim=True), + ).squeeze(1) + row_offsets = match.row_offsets.index_select( + 0, + torch.nonzero(has_label, as_tuple=False).reshape(-1), + ) + candidates = candidates.masked_select(has_label) + unset = references.index_select(0, row_offsets) == -100 + if bool(unset.any()): + references[row_offsets.masked_select(unset)] = candidates.masked_select( + unset + ) + if bool((references == -100).any()): + return None + return references + + +def _consistent_row_labels( + label_rows: Sequence[torch.Tensor | None], + row_matches: Sequence[_RowMatch], + *, + row_count: int, + device: torch.device, +) -> torch.Tensor | None: + labels = torch.full( + (row_count,), + -100, + dtype=torch.long, + device=device, + ) + has_label = torch.zeros_like(labels, dtype=torch.bool) + for item_labels, match in zip(label_rows, row_matches, strict=True): + if item_labels is None: + continue + if int(match.source_offsets.numel()) == 0: + continue + selected_labels = item_labels.index_select(0, match.source_offsets) + keep = selected_labels != -100 + if not bool(keep.any().item()): + continue + kept_row_offsets = match.row_offsets[keep] + kept_labels = selected_labels[keep] + existing = labels.index_select(0, kept_row_offsets) + seen = 
has_label.index_select(0, kept_row_offsets) + if bool(((existing != kept_labels) & seen).any().item()): + return None + labels.index_copy_(0, kept_row_offsets, kept_labels) + has_label.index_fill_(0, kept_row_offsets, True) + return labels + + +def _scatter_row_target_logprobs( + row_target_logprobs: torch.Tensor, + row_matches: Sequence[_RowMatch], + label_rows: Sequence[torch.Tensor | None], + target_logprobs: list[torch.Tensor | None], +) -> None: + for match, labels, item_logprobs in zip( + row_matches, + label_rows, + target_logprobs, + strict=True, + ): + if labels is None or item_logprobs is None: + continue + if int(match.source_offsets.numel()) == 0: + continue + item_logprobs[match.source_offsets] = row_target_logprobs.index_select( + 0, + match.row_offsets, + ) + + +def _topk_from_full_logits( + logits: torch.Tensor, + *, + k: int, + log_z: torch.Tensor, +) -> TopK: + if k > int(logits.shape[1]): + raise ValueError(f"top_k={k} exceeds vocabulary size {int(logits.shape[1])}") + values, tokens = torch.topk(logits.float(), k=k, dim=-1) + return TopK(logprobs=values - log_z.unsqueeze(1), tokens=tokens) + + +def _vocab_parallel_topk( + local_logits: torch.Tensor, + *, + k: int, + log_z: torch.Tensor, +) -> TopK: + start, _ = _vocab_range(local_logits) + local_k = min(k, int(local_logits.shape[1])) + local_values, local_tokens = torch.topk(local_logits.float(), k=local_k, dim=-1) + local_values = local_values - log_z.unsqueeze(1) + local_tokens = local_tokens + start + + from megatron.core import parallel_state as ps + + tp_size = int(ps.get_tensor_model_parallel_world_size()) + if tp_size <= 1: + return TopK(logprobs=local_values, tokens=local_tokens) + + from torch.distributed.nn.functional import all_gather + + group = ps.get_tensor_model_parallel_group(check_initialized=False) + gathered_values = cast(tuple[torch.Tensor, ...], all_gather(local_values, group)) + gathered_tokens = [torch.empty_like(local_tokens) for _ in range(tp_size)] + 
dist.all_gather(gathered_tokens, local_tokens, group=group) + values = torch.cat(gathered_values, dim=1) + tokens = torch.cat(gathered_tokens, dim=1) + if k > int(values.shape[1]): + raise ValueError(f"top_k={k} exceeds vocabulary size {int(values.shape[1])}") + top_values, top_offsets = torch.topk(values, k=k, dim=-1) + return TopK(logprobs=top_values, tokens=tokens.gather(1, top_offsets)) + + +def _try_triton_local_topk_stats( + local_logits: torch.Tensor, + *, + k: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None: + if k <= 0: + return None + if k > _triton_fused_topk_max(): + return None + if not local_logits.is_cuda: + return None + if _triton_topk_disabled(): + return None + if int(local_logits.shape[0]) < _triton_min_rows(): + return None + try: + from art.megatron.trainer_rank_topk import local_topk_stats + + stats = local_topk_stats( + local_logits, + k=min(k, int(local_logits.shape[1])), + ) + except Exception: + if _triton_topk_strict(): + raise + return None + return stats.local_max, stats.local_sum, stats.values, stats.tokens + + +def _try_triton_local_logsumexp_stats( + local_logits: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor] | None: + if not local_logits.is_cuda: + return None + if _triton_topk_disabled(): + return None + if int(local_logits.shape[0]) < _triton_min_rows(): + return None + try: + from art.megatron.trainer_rank_topk import local_logsumexp_stats + + stats = local_logsumexp_stats(local_logits) + except Exception: + if _triton_topk_strict(): + raise + return None + return stats.local_max, stats.local_sum + + +def _triton_topk_disabled() -> bool: + return os.environ.get("ART_TRAINER_RANK_TRITON_TOPK", "1").lower() in { + "0", + "false", + } + + +def _triton_topk_strict() -> bool: + return os.environ.get("ART_TRAINER_RANK_TRITON_TOPK", "1").lower() == "strict" + + +def _triton_fused_topk_max() -> int: + # H200 measurements: fused top-k wins through k=10; above that the + # logsumexp-only Triton path 
def _vocab_parallel_topk_from_local(
    local_values: torch.Tensor,
    local_tokens: torch.Tensor,
    *,
    k: int,
    log_z: torch.Tensor,
    vocab_start: int,
) -> TopK:
    """Merge per-shard top candidates into the global top-``k`` log-probabilities.

    The first ``min(k, columns)`` local candidates are converted to
    log-probabilities by subtracting ``log_z`` and their token ids are
    shifted by ``vocab_start``. With tensor parallelism the candidates of all
    ranks are gathered and the best ``k`` per row are kept; without it the
    local candidates are returned directly.

    Raises:
        ValueError: if fewer than ``k`` candidates are available.
    """
    width = min(k, int(local_values.shape[1]))
    shard_logprobs = local_values[:, :width] - log_z.unsqueeze(1)
    shard_tokens = local_tokens[:, :width] + vocab_start

    from megatron.core import parallel_state as ps

    world = int(ps.get_tensor_model_parallel_world_size())
    if world <= 1:
        available = int(shard_logprobs.shape[1])
        if k > available:
            raise ValueError(f"top_k={k} exceeds vocabulary size {available}")
        return TopK(logprobs=shard_logprobs, tokens=shard_tokens)

    from torch.distributed.nn.functional import all_gather

    tp_group = ps.get_tensor_model_parallel_group(check_initialized=False)
    # Values use the functional all_gather; integer token ids use the plain one.
    value_parts = cast(tuple[torch.Tensor, ...], all_gather(shard_logprobs, tp_group))
    token_parts = [torch.empty_like(shard_tokens) for _ in range(world)]
    dist.all_gather(token_parts, shard_tokens, group=tp_group)

    merged_values = torch.cat(value_parts, dim=1)
    merged_tokens = torch.cat(token_parts, dim=1)
    candidates = int(merged_values.shape[1])
    if k > candidates:
        raise ValueError(f"top_k={k} exceeds vocabulary size {candidates}")
    best_values, best_columns = torch.topk(merged_values, k=k, dim=-1)
    return TopK(logprobs=best_values, tokens=merged_tokens.gather(1, best_columns))
+ device=values.tokens.device, + dtype=values.tokens.dtype, + ), + ) + current.logprobs[offsets] = values.logprobs + current.tokens[offsets] = values.tokens + return current + + +def _vocab_parallel_log_z(local_logits: torch.Tensor) -> torch.Tensor: + local_logits = local_logits.float() + local_max = local_logits.max(dim=-1).values.detach() + global_max = _all_reduce_tensor_parallel_max(local_max) + local_sum = _call_compiled(_local_vocab_exp_sum, local_logits, global_max) + global_sum = _all_reduce_tensor_parallel_sum(local_sum) + return global_max + torch.log(global_sum) + + +def _local_vocab_exp_sum( + local_logits: torch.Tensor, + global_max: torch.Tensor, +) -> torch.Tensor: + return torch.exp(local_logits.float() - global_max.unsqueeze(1)).sum(dim=-1) + + +def _vocab_range(local_logits: torch.Tensor) -> tuple[int, int]: + from megatron.core import parallel_state as ps + + local_size = int(local_logits.shape[1]) + rank = int(ps.get_tensor_model_parallel_rank()) + start = rank * local_size + return start, start + local_size + + +def _all_reduce_tensor_parallel_sum(tensor: torch.Tensor) -> torch.Tensor: + from megatron.core import parallel_state as ps + + if int(ps.get_tensor_model_parallel_world_size()) <= 1: + return tensor + from torch.distributed.nn.functional import all_reduce + + return cast( + torch.Tensor, + all_reduce( + tensor, + op=dist.ReduceOp.SUM, + group=ps.get_tensor_model_parallel_group(check_initialized=False), + ), + ) + + +def _all_reduce_tensor_parallel_max(tensor: torch.Tensor) -> torch.Tensor: + from megatron.core import parallel_state as ps + + if int(ps.get_tensor_model_parallel_world_size()) <= 1: + return tensor + output = tensor.clone() + dist.all_reduce( + output, + op=dist.ReduceOp.MAX, + group=ps.get_tensor_model_parallel_group(check_initialized=False), + ) + return output + + +def _call_compiled(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + if os.environ.get("ART_TRAINER_RANK_COMPILE", "0").lower() in {"0", 
"false"}: + return fn(*args, **kwargs) + compiled = _COMPILED_FUNCTIONS.get(fn) + if compiled is None: + compiled = cast(Callable[..., object], torch.compile(fn, dynamic=True)) + _COMPILED_FUNCTIONS[fn] = compiled + try: + return cast(Callable[P, R], compiled)(*args, **kwargs) + except Exception: + return fn(*args, **kwargs) + + +def _matching_offsets( + positions: torch.Tensor, + chunk_rows: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + if int(positions.numel()) == 0 or int(chunk_rows.numel()) == 0: + empty = torch.empty(0, dtype=torch.long, device=positions.device) + return empty, empty + sorted_rows, order = chunk_rows.sort() + indices = torch.searchsorted(sorted_rows, positions) + in_bounds = indices < int(sorted_rows.numel()) + source_offsets = torch.arange( + int(positions.numel()), + device=positions.device, + dtype=torch.long, + )[in_bounds] + found = indices[in_bounds] + keep = sorted_rows.index_select(0, found) == positions.index_select( + 0, + source_offsets, + ) + return source_offsets[keep], order.index_select(0, found[keep]) + + +def _row_matches_by_item( + positions_by_item: Sequence[torch.Tensor], + rows: torch.Tensor, + *, + device: torch.device, +) -> tuple[_RowMatch, ...]: + return tuple( + _row_match(positions.to(device=device), rows) for positions in positions_by_item + ) + + +def _row_match(positions: torch.Tensor, rows: torch.Tensor) -> _RowMatch: + source_offsets, row_offsets = _matching_offsets(positions, rows) + if int(row_offsets.numel()) > 1: + order = row_offsets.argsort() + source_offsets = source_offsets.index_select(0, order) + row_offsets = row_offsets.index_select(0, order) + return _RowMatch(source_offsets=source_offsets, row_offsets=row_offsets) + + +def _match_chunk_offsets( + match: _RowMatch, + *, + start: int, + end: int, +) -> tuple[torch.Tensor, torch.Tensor]: + keep = (match.row_offsets >= start) & (match.row_offsets < end) + source_offsets = match.source_offsets[keep] + return source_offsets, 
match.row_offsets[keep] - start + + +def _local_position_pairs( + local_global_positions: torch.Tensor, + item_positions: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + flat = local_global_positions.reshape(-1).to(device=item_positions.device) + local_positions = torch.nonzero(flat >= 0, as_tuple=False).reshape(-1) + global_positions = flat.index_select(0, local_positions) + sort_order = global_positions.argsort() + sorted_global_positions = global_positions.index_select(0, sort_order) + sorted_local_positions = local_positions.index_select(0, sort_order) + + indices = torch.searchsorted(sorted_global_positions, item_positions) + in_bounds = indices < int(sorted_global_positions.numel()) + source_offsets = torch.arange( + int(item_positions.numel()), + device=item_positions.device, + dtype=torch.long, + )[in_bounds] + found = indices[in_bounds] + keep = sorted_global_positions.index_select( + 0, found + ) == item_positions.index_select( + 0, + source_offsets, + ) + return ( + sorted_local_positions.index_select(0, found[keep]).to("cpu"), + source_offsets[keep].to("cpu"), + ) + + +def _select_positions(values: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + if int(positions.numel()) == 0: + return values[:0] + return values.index_select(0, positions.to(device=values.device)) + + +def _gather_target_logprobs( + logprobs: torch.Tensor, + labels: torch.Tensor, +) -> torch.Tensor: + if int(labels.shape[0]) == 0: + return torch.empty(labels.shape, device=logprobs.device, dtype=logprobs.dtype) + flat_labels = labels.clamp_min(0).reshape(int(labels.shape[0]), -1) + selected = logprobs.gather(1, flat_labels).reshape(labels.shape) + return selected.masked_fill(labels == -100, 0.0) + + +def _batch_seq_logits(logits: torch.Tensor, *, seq_len: int) -> torch.Tensor: + if int(logits.ndim) != 3: + raise RuntimeError( + f"expected logits with shape [B, S, V] or [S, B, V], got {tuple(logits.shape)}" + ) + if int(logits.shape[0]) == 1 and int(logits.shape[1]) == 
# Stage 1 of the fused top-k / logsumexp reduction.
# Each program handles one (row, vocabulary block) pair and writes, for that
# block: its maximum, the sum of exp(value - block maximum), and its k largest
# values together with their column indices within the row.
@triton.jit
def _topk_stage1_kernel(
    logits_ptr,
    partial_max_ptr,
    partial_sum_ptr,
    partial_values_ptr,
    partial_tokens_ptr,
    stride_row: tl.constexpr,
    vocab_size: tl.constexpr,
    n_blocks: tl.constexpr,
    k: tl.constexpr,
    block_v: tl.constexpr,
):
    row = tl.program_id(0)
    block = tl.program_id(1)
    offsets = block * block_v + tl.arange(0, block_v)
    mask = offsets < vocab_size
    # Columns past the end of the vocabulary read as -inf, so they affect
    # neither the maximum nor the exponential sum. Math is done in float32.
    values = tl.load(
        logits_ptr + row * stride_row + offsets,
        mask=mask,
        other=-float("inf"),
    ).to(tl.float32)

    block_max = tl.max(values, axis=0)
    block_sum = tl.sum(tl.exp(values - block_max), axis=0)
    partial_offset = row * n_blocks + block
    tl.store(partial_max_ptr + partial_offset, block_max)
    tl.store(partial_sum_ptr + partial_offset, block_sum)

    # Selection by repeated maximum: take the current maximum, record it,
    # then overwrite that position with -inf so the next round skips it.
    work = values
    arange = tl.arange(0, block_v)
    for slot in tl.static_range(0, k):
        top_value, top_index = tl.max(
            work,
            axis=0,
            return_indices=True,
            return_indices_tie_break_left=True,
        )
        output_offset = (partial_offset * k) + slot
        tl.store(partial_values_ptr + output_offset, top_value)
        # Stored index is the column within the row, not within the block.
        tl.store(
            partial_tokens_ptr + output_offset,
            (block * block_v + top_index).to(tl.int64),
        )
        work = tl.where(arange == top_index, -float("inf"), work)
# Stage 2 of the logsumexp reduction: one program per row combines the
# per-block (max, sum) partials into a single row maximum and a single sum of
# exponentials taken relative to that row maximum.
@triton.jit
def _logsumexp_stage2_kernel(
    partial_max_ptr,
    partial_sum_ptr,
    local_max_ptr,
    local_sum_ptr,
    n_blocks: tl.constexpr,
    block_b: tl.constexpr,
):
    row = tl.program_id(0)
    block_offsets = tl.arange(0, block_b)
    # Lanes at or beyond n_blocks are masked: they load -inf / 0.0 and so
    # contribute nothing to the maximum or the sum.
    block_mask = block_offsets < n_blocks
    partial_base = row * n_blocks
    block_max = tl.load(
        partial_max_ptr + partial_base + block_offsets,
        mask=block_mask,
        other=-float("inf"),
    )
    row_max = tl.max(block_max, axis=0)
    block_sum = tl.load(
        partial_sum_ptr + partial_base + block_offsets,
        mask=block_mask,
        other=0.0,
    )
    tl.store(local_max_ptr + row, row_max)
    # Each block sum is relative to its own block maximum; rescale it to the
    # row maximum before adding the blocks together.
    tl.store(
        local_sum_ptr + row, tl.sum(block_sum * tl.exp(block_max - row_max), axis=0)
    )
tl.constexpr, + block_v: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offsets = block * block_v + tl.arange(0, block_v) + mask = offsets < vocab_size + + logits = tl.load( + logits_ptr + row * stride_row + offsets, + mask=mask, + other=-float("inf"), + ).to(tl.float32) + local_max = tl.load(local_max_ptr + row) + grad = tl.load(grad_sum_ptr + row).to(tl.float32) * tl.exp(logits - local_max) + + for slot in tl.static_range(0, k): + token = tl.load(tokens_ptr + row * k + slot) + value_grad = tl.load(grad_values_ptr + row * k + slot).to(tl.float32) + grad += tl.where(offsets == token, value_grad, 0.0) + + tl.store(grad_logits_ptr + row * stride_row + offsets, grad, mask=mask) + + +@triton.jit +def _logsumexp_backward_kernel( + logits_ptr, + local_max_ptr, + grad_sum_ptr, + grad_logits_ptr, + stride_row: tl.constexpr, + vocab_size: tl.constexpr, + block_v: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offsets = block * block_v + tl.arange(0, block_v) + mask = offsets < vocab_size + logits = tl.load( + logits_ptr + row * stride_row + offsets, + mask=mask, + other=-float("inf"), + ).to(tl.float32) + local_max = tl.load(local_max_ptr + row) + grad = tl.load(grad_sum_ptr + row).to(tl.float32) * tl.exp(logits - local_max) + tl.store(grad_logits_ptr + row * stride_row + offsets, grad, mask=mask) + + +class _LocalTopKStatsFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, local_logits: torch.Tensor, k: int): + stats = _local_topk_stats_forward(local_logits, k=k) + ctx.save_for_backward(local_logits, stats.local_max, stats.tokens) + ctx.k = k + return stats.local_max, stats.local_sum, stats.values, stats.tokens + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + grad_local_max, grad_local_sum, grad_values, grad_tokens = grad_outputs + del grad_local_max, grad_tokens + logits, local_max, tokens = ctx.saved_tensors + k = int(ctx.k) + rows = int(logits.shape[0]) + vocab_size = 
int(logits.shape[1]) + block_v = 4096 + n_blocks = triton.cdiv(vocab_size, block_v) + + if grad_local_sum is None: + grad_local_sum = torch.zeros_like(local_max) + if grad_values is None: + grad_values = torch.zeros( + (rows, k), + device=logits.device, + dtype=torch.float32, + ) + + grad_logits = torch.empty_like(logits) + _topk_backward_kernel[(rows, n_blocks)]( + logits, + local_max, + tokens, + grad_local_sum.contiguous(), + grad_values.contiguous(), + grad_logits, + logits.stride(0), + vocab_size, # ty: ignore[invalid-argument-type] + k, # ty: ignore[invalid-argument-type] + block_v, # ty: ignore[invalid-argument-type] + num_warps=8, # ty: ignore[unknown-argument] + ) + return grad_logits, None + + +class _LocalLogSumExpStatsFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, local_logits: torch.Tensor): + stats = _local_logsumexp_stats_forward(local_logits) + ctx.save_for_backward(local_logits, stats.local_max) + return stats.local_max, stats.local_sum + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + grad_local_max, grad_local_sum = grad_outputs + del grad_local_max + logits, local_max = ctx.saved_tensors + rows = int(logits.shape[0]) + vocab_size = int(logits.shape[1]) + block_v = 4096 + n_blocks = triton.cdiv(vocab_size, block_v) + + if grad_local_sum is None: + grad_local_sum = torch.zeros_like(local_max) + + grad_logits = torch.empty_like(logits) + _logsumexp_backward_kernel[(rows, n_blocks)]( + logits, + local_max, + grad_local_sum.contiguous(), + grad_logits, + logits.stride(0), + vocab_size, # ty: ignore[invalid-argument-type] + block_v, # ty: ignore[invalid-argument-type] + num_warps=8, # ty: ignore[unknown-argument] + ) + return grad_logits + + +def _check_local_logits(local_logits: torch.Tensor) -> torch.Tensor: + if local_logits.ndim != 2: + raise ValueError( + f"expected [rows, vocab] logits, got {tuple(local_logits.shape)}" + ) + if not local_logits.is_cuda: + raise ValueError("local top-k helpers require 
CUDA logits") + return local_logits.contiguous() + + +def _local_topk_stats_forward(local_logits: torch.Tensor, *, k: int) -> LocalTopKStats: + logits = _check_local_logits(local_logits) + if k < 1 or k > int(local_logits.shape[1]): + raise ValueError( + f"k={k} is outside local vocab size {int(local_logits.shape[1])}" + ) + + rows = int(logits.shape[0]) + vocab_size = int(logits.shape[1]) + block_v = 4096 + n_blocks = triton.cdiv(vocab_size, block_v) + block_b = triton.next_power_of_2(n_blocks) + block_candidates = triton.next_power_of_2(n_blocks * k) + + partial_shape = (rows, n_blocks) + partial_topk_shape = (rows, n_blocks, k) + partial_max = torch.empty(partial_shape, device=logits.device, dtype=torch.float32) + partial_sum = torch.empty_like(partial_max) + partial_values = torch.empty( + partial_topk_shape, + device=logits.device, + dtype=torch.float32, + ) + partial_tokens = torch.empty( + partial_topk_shape, + device=logits.device, + dtype=torch.long, + ) + local_max = torch.empty((rows,), device=logits.device, dtype=torch.float32) + local_sum = torch.empty_like(local_max) + values = torch.empty((rows, k), device=logits.device, dtype=torch.float32) + tokens = torch.empty((rows, k), device=logits.device, dtype=torch.long) + + _topk_stage1_kernel[(rows, n_blocks)]( + logits, + partial_max, + partial_sum, + partial_values, + partial_tokens, + logits.stride(0), # ty: ignore[invalid-argument-type] + vocab_size, # ty: ignore[invalid-argument-type] + n_blocks, + k, # ty: ignore[invalid-argument-type] + block_v, # ty: ignore[invalid-argument-type] + num_warps=8, # ty: ignore[unknown-argument] + ) + _topk_stage2_kernel[(rows,)]( + partial_max, + partial_sum, + partial_values, + partial_tokens, + local_max, + local_sum, + values, + tokens, + n_blocks, + k, # ty: ignore[invalid-argument-type] + block_b, + block_candidates, + num_warps=8, # ty: ignore[unknown-argument] + ) + return LocalTopKStats( + local_max=local_max, + local_sum=local_sum, + values=values, + 
tokens=tokens, + ) + + +def _local_logsumexp_stats_forward(local_logits: torch.Tensor) -> LocalLogSumExpStats: + logits = _check_local_logits(local_logits) + rows = int(logits.shape[0]) + vocab_size = int(logits.shape[1]) + block_v = 4096 + n_blocks = triton.cdiv(vocab_size, block_v) + block_b = triton.next_power_of_2(n_blocks) + + partial_shape = (rows, n_blocks) + partial_max = torch.empty(partial_shape, device=logits.device, dtype=torch.float32) + partial_sum = torch.empty_like(partial_max) + local_max = torch.empty((rows,), device=logits.device, dtype=torch.float32) + local_sum = torch.empty_like(local_max) + + _logsumexp_stage1_kernel[(rows, n_blocks)]( + logits, + partial_max, + partial_sum, + logits.stride(0), # ty: ignore[invalid-argument-type] + vocab_size, # ty: ignore[invalid-argument-type] + n_blocks, + block_v, # ty: ignore[invalid-argument-type] + num_warps=8, # ty: ignore[unknown-argument] + ) + _logsumexp_stage2_kernel[(rows,)]( + partial_max, + partial_sum, + local_max, + local_sum, + n_blocks, + block_b, + num_warps=8, # ty: ignore[unknown-argument] + ) + return LocalLogSumExpStats(local_max=local_max, local_sum=local_sum) + + +def local_topk_stats(local_logits: torch.Tensor, *, k: int) -> LocalTopKStats: + logits = local_logits.contiguous() + if not logits.requires_grad: + return _local_topk_stats_forward(logits, k=k) + local_max, local_sum, values, tokens = _LocalTopKStatsFunction.apply(logits, k) + return LocalTopKStats( + local_max=local_max, + local_sum=local_sum, + values=values, + tokens=tokens, + ) + + +def local_logsumexp_stats(local_logits: torch.Tensor) -> LocalLogSumExpStats: + logits = local_logits.contiguous() + if not logits.requires_grad: + return _local_logsumexp_stats_forward(logits) + local_max, local_sum = _LocalLogSumExpStatsFunction.apply(logits) + return LocalLogSumExpStats(local_max=local_max, local_sum=local_sum) diff --git a/src/art/megatron/training/finalize_grads.py b/src/art/megatron/training/finalize_grads.py index 
cde0e7b06..2c49671fa 100644 --- a/src/art/megatron/training/finalize_grads.py +++ b/src/art/megatron/training/finalize_grads.py @@ -28,6 +28,8 @@ def _iter_named_trainable_parameters( for name, param in model_chunk.named_parameters(): if not param.requires_grad: continue + if getattr(param, "_art_dynamic_lora_slot", False): + continue param_id = id(param) if param_id in seen: continue diff --git a/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py b/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py index 3d3d51d4c..58670a685 100644 --- a/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py +++ b/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py @@ -82,40 +82,19 @@ def test_shared_prefix_attention_matches_flattened_grad_accumulation() -> None: ref_out = torch.zeros_like(packed_out) ref_loss = q_ref.new_zeros(()) - for family in spec.families: - prefix = family.prefix - prefix_grad_used = False - for completion in family.completions: - indices = torch.tensor( - [ - *range(prefix.start, prefix.end), - *range(completion.start, completion.end), - ], - device=q.device, - dtype=torch.long, - ) - row = family.row_index - q_slice = q_ref[row : row + 1].index_select(2, indices) - k_slice = k_ref[row : row + 1].index_select(2, indices) - v_slice = v_ref[row : row + 1].index_select(2, indices) - flat_out = _dense_causal_attention(q_slice, k_slice, v_slice) - - ref_out[row, :, completion.start : completion.end] = flat_out[ - 0, :, prefix.length : - ] - flat_grad = torch.zeros_like(flat_out) - flat_grad[0, :, prefix.length :] = output_grad[ - row, :, completion.start : completion.end - ] - if not prefix_grad_used: - ref_out[row, :, prefix.start : prefix.end] = flat_out[ - 0, :, : prefix.length - ] - flat_grad[0, :, : prefix.length] = output_grad[ - row, :, prefix.start : prefix.end - ] - prefix_grad_used = True - ref_loss = ref_loss + (flat_out * flat_grad).sum() + for segment_index, segment 
in enumerate(spec.tree_segments): + indices, output_slice = _segment_context_positions(spec, segment_index) + index_tensor = torch.tensor(indices, device=q.device, dtype=torch.long) + row = segment.row_index + q_slice = q_ref[row : row + 1].index_select(2, index_tensor) + k_slice = k_ref[row : row + 1].index_select(2, index_tensor) + v_slice = v_ref[row : row + 1].index_select(2, index_tensor) + flat_out = _dense_causal_attention(q_slice, k_slice, v_slice) + + ref_out[row, :, segment.start : segment.end] = flat_out[0, :, output_slice] + flat_grad = torch.zeros_like(flat_out) + flat_grad[0, :, output_slice] = output_grad[row, :, segment.start : segment.end] + ref_loss = ref_loss + (flat_out * flat_grad).sum() ref_loss.backward() real_mask = _real_token_mask(spec, q.shape, device=q.device) @@ -225,11 +204,27 @@ def _completion_token_mask( spec: Any, shape: torch.Size, *, device: torch.device ) -> torch.Tensor: mask = torch.zeros(shape, device=device, dtype=torch.bool) - for family in spec.families: - for completion in family.completions: - mask[ - family.row_index, - :, - completion.start : completion.end, - ] = True + for index, segment in enumerate(spec.tree_segments): + if spec.tree_parent_indices[index] >= 0: + mask[segment.row_index, :, segment.start : segment.end] = True return mask + + +def _segment_context_positions( + spec: Any, segment_index: int +) -> tuple[list[int], slice]: + path = [] + cursor = segment_index + while cursor >= 0: + path.append(cursor) + cursor = spec.tree_parent_indices[cursor] + path.reverse() + positions = [ + position + for index in path + for position in range( + spec.tree_segments[index].start, spec.tree_segments[index].end + ) + ] + segment_length = spec.tree_segments[segment_index].length + return positions, slice(len(positions) - segment_length, len(positions)) diff --git a/tests/integration/megatron/gdn_shared_prefix/oracles.py b/tests/integration/megatron/gdn_shared_prefix/oracles.py index 3d3f9ae12..3820bbdb5 100644 --- 
a/tests/integration/megatron/gdn_shared_prefix/oracles.py +++ b/tests/integration/megatron/gdn_shared_prefix/oracles.py @@ -7,6 +7,8 @@ from torch import Tensor import torch.nn.functional as F +from art.megatron.gdn.gdn_shared_prefix import GdnPackedExecutionSpec, GdnSegmentSpec + from .metrics import ( mean_abs_pct, parameter_grad_mean_abs_pct_with_name, @@ -111,23 +113,25 @@ def run_toy_packed( group_ids, parent_ids, min_completions_per_family=1 ) output = torch.zeros_like(hidden) - for family in spec.families: - row = family.row_index - prefix_hidden = hidden[row, family.prefix.start : family.prefix.end] - prefix_out, prefix_conv, prefix_rec = module.forward_segment( - prefix_hidden, - conv_initial=module.zero_conv_state(hidden), - recurrent_initial=module.zero_recurrent_state(hidden), + conv_states: list[Tensor] = [] + rec_states: list[Tensor] = [] + for segment_index, segment in enumerate(spec.tree_segments): + row = segment.row_index + parent_index = spec.tree_parent_indices[segment_index] + if parent_index < 0: + conv_initial = module.zero_conv_state(hidden) + rec_initial = module.zero_recurrent_state(hidden) + else: + conv_initial = conv_states[parent_index] + rec_initial = rec_states[parent_index] + segment_out, conv_final, rec_final = module.forward_segment( + hidden[row, segment.start : segment.end], + conv_initial=conv_initial, + recurrent_initial=rec_initial, ) - output[row, family.prefix.start : family.prefix.end] = prefix_out - for completion in family.completions: - suffix_hidden = hidden[row, completion.start : completion.end] - suffix_out, _, _ = module.forward_segment( - suffix_hidden, - conv_initial=prefix_conv, - recurrent_initial=prefix_rec, - ) - output[row, completion.start : completion.end] = suffix_out + output[row, segment.start : segment.end] = segment_out + conv_states.append(conv_final) + rec_states.append(rec_final) return output @@ -142,26 +146,34 @@ def run_toy_flattened_reference( group_ids, parent_ids, min_completions_per_family=1 
) output = torch.zeros_like(hidden) - for family in spec.families: - row = family.row_index - prefix_hidden = hidden[row, family.prefix.start : family.prefix.end] - prefix_len = family.prefix.length - for child_index, completion in enumerate(family.completions): - suffix_hidden = hidden[row, completion.start : completion.end] - flattened = torch.cat([prefix_hidden, suffix_hidden], dim=0) - flat_out, _, _ = module.forward_segment( - flattened, - conv_initial=module.zero_conv_state(hidden), - recurrent_initial=module.zero_recurrent_state(hidden), - ) - if child_index == 0: - output[row, family.prefix.start : family.prefix.end] = flat_out[ - :prefix_len - ] - output[row, completion.start : completion.end] = flat_out[prefix_len:] + for segment_index, segment in enumerate(spec.tree_segments): + path = _segment_path(spec, segment_index) + flattened = torch.cat( + [hidden[node.row_index, node.start : node.end] for node in path], + dim=0, + ) + flat_out, _, _ = module.forward_segment( + flattened, + conv_initial=module.zero_conv_state(hidden), + recurrent_initial=module.zero_recurrent_state(hidden), + ) + segment_len = segment.length + output[segment.row_index, segment.start : segment.end] = flat_out[-segment_len:] return output +def _segment_path( + spec: GdnPackedExecutionSpec, + segment_index: int, +) -> tuple[GdnSegmentSpec, ...]: + indices = [] + cursor = segment_index + while cursor >= 0: + indices.append(cursor) + cursor = spec.tree_parent_indices[cursor] + return tuple(spec.tree_segments[index] for index in reversed(indices)) + + def run_toy_physical_stream( module: ToyStatefulGdn, hidden: Tensor, diff --git a/tests/integration/megatron/gdn_shared_prefix/packed_layout.py b/tests/integration/megatron/gdn_shared_prefix/packed_layout.py index 45a41ff58..a56b801b3 100644 --- a/tests/integration/megatron/gdn_shared_prefix/packed_layout.py +++ b/tests/integration/megatron/gdn_shared_prefix/packed_layout.py @@ -141,7 +141,9 @@ def summarize_case( tensors["group_ids"], 
tensors["parent_ids"], min_completions_per_family=1 ) suffix_lengths = [ - segment.length for family in spec.families for segment in family.completions + segment.length + for index, segment in enumerate(spec.tree_segments) + if spec.tree_parent_indices[index] >= 0 ] boundary = _boundary_flags(spec, cp_sizes) return GdnCaseSummary( @@ -227,19 +229,49 @@ def _boundary_flags( boundaries = {shard * rank for rank in range(1, cp_size)} if shard * (cp_size - 1) >= spec.real_token_count: flags["empty_trailing_rank"] = True - for family in spec.families: - family_start = _segment_real_start(family.prefix, spec, real_index) - family_end = _segment_real_end(family.completions[-1], spec, real_index) + for root in _root_segments(spec): + descendants = _descendant_segments(spec, root.family_index) + family_segments = (root, *descendants) + family_start = min( + _segment_real_start(segment, spec, real_index) + for segment in family_segments + ) + family_end = max( + _segment_real_end(segment, spec, real_index) + for segment in family_segments + ) if family_start in boundaries or family_end in boundaries: flags["family_boundary_at_partition"] = True - if _crosses_boundary(family.prefix, spec, real_index, boundaries): + if _crosses_boundary(root, spec, real_index, boundaries): flags["cp_boundary_prefix"] = True - for completion in family.completions: + for completion in descendants: if _crosses_boundary(completion, spec, real_index, boundaries): flags["cp_boundary_suffix"] = True return flags +def _root_segments(spec: GdnPackedExecutionSpec) -> tuple[Any, ...]: + return tuple( + segment + for index, segment in enumerate(spec.tree_segments) + if spec.tree_parent_indices[index] < 0 + ) + + +def _descendant_segments( + spec: GdnPackedExecutionSpec, root_index: int +) -> tuple[Any, ...]: + descendants = [] + for index, segment in enumerate(spec.tree_segments): + parent = spec.tree_parent_indices[index] + while parent >= 0: + if parent == root_index: + descendants.append(segment) + 
break + parent = spec.tree_parent_indices[parent] + return tuple(descendants) + + def _segment_real_start( segment: Any, spec: GdnPackedExecutionSpec, real_index: dict[int, int] ) -> int: diff --git a/tests/integration/megatron/gdn_shared_prefix/parser_import.py b/tests/integration/megatron/gdn_shared_prefix/parser_import.py index ce184d96e..3a473ebf3 100644 --- a/tests/integration/megatron/gdn_shared_prefix/parser_import.py +++ b/tests/integration/megatron/gdn_shared_prefix/parser_import.py @@ -24,6 +24,5 @@ def _load_parser_module() -> ModuleType: _MODULE = _load_parser_module() GdnPackedExecutionSpec: Any = _MODULE.GdnPackedExecutionSpec -build_gdn_cp_segment_schedule: Any = _MODULE.build_gdn_cp_segment_schedule build_gdn_rank_execution_plan: Any = _MODULE.build_gdn_rank_execution_plan parse_gdn_shared_prefix_segments: Any = _MODULE.parse_gdn_shared_prefix_segments diff --git a/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py b/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py index e69fef22b..ee472adaa 100644 --- a/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py +++ b/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Literal +from typing import Any, Literal, NamedTuple from pydantic import BaseModel, ConfigDict import torch @@ -61,6 +61,57 @@ class GdnChainBoundaryDebug(BaseModel): ] +class _TreeFamily(NamedTuple): + row_index: int + family_index: int + prefix: Any + completions: tuple[Any, ...] + segment_indices: tuple[int, ...] + parent_indices: tuple[int, ...] 
+ + @property + def token_count(self) -> int: + return self.prefix.length + sum(segment.length for segment in self.completions) + + +def _segment_path(spec: Any, segment_index: int) -> tuple[Any, ...]: + path = [] + cursor = segment_index + while cursor >= 0: + path.append(cursor) + cursor = spec.tree_parent_indices[cursor] + return tuple(spec.tree_segments[index] for index in reversed(path)) + + +def _tree_families(spec: Any) -> tuple[_TreeFamily, ...]: + families = [] + for root_index, root in enumerate(spec.tree_segments): + if spec.tree_parent_indices[root_index] >= 0: + continue + segment_indices = [root_index] + for index in range(root_index + 1, len(spec.tree_segments)): + parent = spec.tree_parent_indices[index] + while parent >= 0: + if parent == root_index: + segment_indices.append(index) + break + parent = spec.tree_parent_indices[parent] + segments = tuple(spec.tree_segments[index] for index in segment_indices) + families.append( + _TreeFamily( + row_index=root.row_index, + family_index=root_index, + prefix=root, + completions=segments[1:], + segment_indices=tuple(segment_indices), + parent_indices=tuple( + spec.tree_parent_indices[index] for index in segment_indices + ), + ) + ) + return tuple(families) + + def compare_real_gdn_cp1_to_flattened( *, packed_gdn: Any, @@ -300,31 +351,32 @@ def run_real_gdn_flattened_reference( group_ids, parent_ids, min_completions_per_family=1 ) output = torch.zeros_like(hidden_states) - for family in spec.families: - row = family.row_index - prefix_hidden = hidden_states[ - family.prefix.start : family.prefix.end, row : row + 1, : - ] - prefix_len = family.prefix.length - for child_index, completion in enumerate(family.completions): - suffix_hidden = hidden_states[ - completion.start : completion.end, row : row + 1, : - ] - flat_hidden = torch.cat([prefix_hidden, suffix_hidden], dim=0) - flat_out, _, _, _ = _run_gdn_segment( - gdn, - flat_hidden, - conv_initial=_zero_conv_state(gdn, hidden_states, row), - 
recurrent_initial=_zero_recurrent_state(gdn, hidden_states, row), - output_final_state=False, - ) - if child_index == 0: - output[family.prefix.start : family.prefix.end, row : row + 1, :] = ( - flat_out[:prefix_len] - ) - output[completion.start : completion.end, row : row + 1, :] = flat_out[ - prefix_len: - ] + for segment_index, segment in enumerate(spec.tree_segments): + flat_hidden = torch.cat( + [ + hidden_states[ + node.start : node.end, + node.row_index : node.row_index + 1, + :, + ] + for node in _segment_path(spec, segment_index) + ], + dim=0, + ) + flat_out, _, _, _ = _run_gdn_segment( + gdn, + flat_hidden, + conv_initial=_zero_conv_state(gdn, hidden_states, segment.row_index), + recurrent_initial=_zero_recurrent_state( + gdn, hidden_states, segment.row_index + ), + output_final_state=False, + ) + output[ + segment.start : segment.end, + segment.row_index : segment.row_index + 1, + :, + ] = flat_out[-segment.length :] return output @@ -414,7 +466,7 @@ def _split_gdn_families_by_rank( raise ValueError(f"cp_size must be >= 1, got {cp_size}") ranks: list[list[int]] = [[] for _ in range(cp_size)] loads = [0] * cp_size - for family in spec.families: + for family in _tree_families(spec): rank = min(range(cp_size), key=lambda index: (loads[index], index)) family_tokens = tuple( token @@ -527,7 +579,7 @@ def run_real_gdn_suffix_only_chain_reference( group_ids, parent_ids, min_completions_per_family=0 ) output = torch.zeros_like(hidden_states) - for family in spec.families: + for family in _tree_families(spec): row = family.row_index zero_conv = _zero_conv_state(gdn, hidden_states, batch_size=1) zero_rec = _zero_recurrent_state(gdn, hidden_states, batch_size=1) @@ -579,7 +631,7 @@ def run_real_gdn_chunk_native_reference( group_ids, parent_ids, min_completions_per_family=0 ) output = torch.zeros_like(hidden_states) - for family in spec.families: + for family in _tree_families(spec): _scatter_family_output( output, family, @@ -603,7 +655,7 @@ def 
run_real_gdn_mixed_cp_reference( output = torch.zeros_like(hidden_states) local_count = 0 chain_count = 0 - for family in spec.families: + for family in _tree_families(spec): if family.token_count <= local_fork_max_tokens: local_count += 1 _scatter_family_output( @@ -753,14 +805,21 @@ def _family_group_tensors( ) -> tuple[Tensor, Tensor]: group_ids = [] parent_ids = [] - prefix_group_id = 0 - group_ids.extend([prefix_group_id] * family.prefix.length) - parent_ids.extend([prefix_group_id] * family.prefix.length) - next_group_id = 1 - for completion in family.completions: - group_ids.extend([next_group_id] * completion.length) - parent_ids.extend([prefix_group_id] * completion.length) - next_group_id += 1 + local_group_by_global: dict[int, int] = {} + for local_group_id, (segment, global_index, parent_index) in enumerate( + zip( + (family.prefix, *family.completions), + family.segment_indices, + family.parent_indices, + strict=True, + ) + ): + local_group_by_global[global_index] = local_group_id + local_parent_id = ( + local_group_id if parent_index < 0 else local_group_by_global[parent_index] + ) + group_ids.extend([local_group_id] * segment.length) + parent_ids.extend([local_parent_id] * segment.length) return ( torch.tensor([group_ids], device=device, dtype=torch.long), torch.tensor([parent_ids], device=device, dtype=torch.long), @@ -883,7 +942,7 @@ def _local_fork_group_tensors( ) parent_ids = torch.full_like(group_ids, -1) next_group_id = 0 - for family in spec.families: + for family in _tree_families(spec): family_segments = (family.prefix, *family.completions) family_tokens = tuple( token_index @@ -898,19 +957,23 @@ def _local_fork_group_tensors( if not all(token_is_local): raise ValueError("local-fork execution requires whole prompt families") - prefix_group_id = next_group_id - next_group_id += 1 - for token_index in family.prefix.linear_indices(spec.sequence_length): - position = local_position[token_index] - group_ids[position] = prefix_group_id - 
parent_ids[position] = prefix_group_id - for completion in family.completions: - child_group_id = next_group_id + group_by_segment_index: dict[int, int] = {} + for segment, global_index, parent_index in zip( + family_segments, + family.segment_indices, + family.parent_indices, + strict=True, + ): + group_id = next_group_id next_group_id += 1 - for token_index in completion.linear_indices(spec.sequence_length): + group_by_segment_index[global_index] = group_id + parent_group_id = ( + group_id if parent_index < 0 else group_by_segment_index[parent_index] + ) + for token_index in segment.linear_indices(spec.sequence_length): position = local_position[token_index] - group_ids[position] = child_group_id - parent_ids[position] = prefix_group_id + group_ids[position] = group_id + parent_ids[position] = parent_group_id if torch.any(group_ids == -1): raise RuntimeError("local-fork metadata left unassigned token rows") return group_ids.unsqueeze(0), parent_ids.unsqueeze(0) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py index 2151b41e1..53d5d62e8 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py @@ -21,6 +21,7 @@ parse_gdn_shared_prefix_segments, ) from art.megatron.gdn.operator import run_gdn_layer # noqa: E402 +from art.megatron.shared_prefix_packing import pack_shared_prefixes # noqa: E402 from .cases import ( # noqa: E402 GdnFamilyShape, @@ -77,6 +78,34 @@ def test_gdn_cp_packed_sibling_order_matches_cp1_oracle( assert (tmp_path / f"cp1_oracle_sibling_rank_{rank}.ok").read_text() == "ok\n" +@pytest.mark.parametrize("cp_size", (2, 4)) +def test_gdn_cp_tree_chain_matches_cp1_oracle(cp_size: int, tmp_path: Path) -> None: + _skip_without_gpus(cp_size) + port = _find_free_port() + mp.spawn( + _tree_chain_oracle_worker, + 
args=(cp_size, port, str(tmp_path)), + nprocs=cp_size, + join=True, + ) + for rank in range(cp_size): + assert (tmp_path / f"tree_chain_rank_{rank}.ok").read_text() == "ok\n" + + +def test_gdn_cp_tree_fuzz_matches_cp1_oracle(tmp_path: Path) -> None: + cp_size = 4 + _skip_without_gpus(cp_size) + port = _find_free_port() + mp.spawn( + _tree_fuzz_oracle_worker, + args=(cp_size, port, str(tmp_path)), + nprocs=cp_size, + join=True, + ) + for rank in range(cp_size): + assert (tmp_path / f"tree_fuzz_rank_{rank}.ok").read_text() == "ok\n" + + def _cp1_oracle_worker( rank: int, cp_size: int, @@ -126,6 +155,86 @@ def _cp1_oracle_worker( destroy_process_group() +def _tree_chain_oracle_worker( + rank: int, + cp_size: int, + port: int, + output_dir: str, +) -> None: + torch.cuda.set_device(rank) + init_process_group( + backend="nccl", + init_method=f"tcp://127.0.0.1:{port}", + rank=rank, + world_size=cp_size, + ) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=cp_size, + expert_model_parallel_size=1, + ) + ref_gdn, cp_gdn = _make_matching_gdn_pair(cp_size=cp_size) + _assert_tree_pack_matches_cp1( + "tree_chain", + ref_gdn, + cp_gdn, + _tree_chain_pack(), + rank=rank, + cp_size=cp_size, + seed=9090, + planner_config=_tree_chain_planner_config(), + require_chain=True, + ) + Path(output_dir, f"tree_chain_rank_{rank}.ok").write_text("ok\n") + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + destroy_process_group() + + +def _tree_fuzz_oracle_worker( + rank: int, + cp_size: int, + port: int, + output_dir: str, +) -> None: + torch.cuda.set_device(rank) + init_process_group( + backend="nccl", + init_method=f"tcp://127.0.0.1:{port}", + rank=rank, + world_size=cp_size, + ) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=cp_size, + expert_model_parallel_size=1, + ) + 
ref_gdn, cp_gdn = _make_matching_gdn_pair(cp_size=cp_size) + for case_index, (name, pack) in enumerate(_tree_fuzz_packs()): + _assert_tree_pack_matches_cp1( + name, + ref_gdn, + cp_gdn, + pack, + rank=rank, + cp_size=cp_size, + seed=9190 + case_index, + planner_config=_tree_fuzz_planner_config(), + require_chain=False, + ) + torch.distributed.barrier() + Path(output_dir, f"tree_fuzz_rank_{rank}.ok").write_text("ok\n") + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + destroy_process_group() + + def _assert_case_matches_cp1( ref_gdn: torch.nn.Module, cp_gdn: torch.nn.Module, @@ -212,6 +321,81 @@ def _assert_case_matches_cp1( ) +def _assert_tree_pack_matches_cp1( + name: str, + ref_gdn: torch.nn.Module, + cp_gdn: torch.nn.Module, + pack: Any, + *, + rank: int, + cp_size: int, + seed: int, + planner_config: GdnPlannerConfig, + require_chain: bool, +) -> None: + zero_parameter_grads(ref_gdn) + zero_parameter_grads(cp_gdn) + group_ids = pack.group_ids.cuda() + parent_ids = pack.parent_ids.cuda() + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) + plan = build_gdn_rank_execution_plan( + spec, + device=group_ids.device, + cp_rank=rank, + cp_size=cp_size, + planner_config=planner_config, + ) + if require_chain: + assert any(plan.tree_chain_buckets_by_depth) + hidden, output_grad = _tree_hidden_and_grad(spec.real_token_count, seed=seed) + ref_hidden = hidden.clone().detach().requires_grad_(True) + ref_out, _ = run_gdn_layer( + ref_gdn, + ref_hidden, + group_ids=group_ids, + parent_ids=parent_ids, + ) + ref_loss = (ref_out * output_grad).sum() + ref_loss.backward() + + flat_hidden = hidden.transpose(0, 1).reshape(-1, hidden.shape[-1]) + flat_grad = output_grad.transpose(0, 1).reshape(-1, output_grad.shape[-1]) + local_index = torch.tensor( + plan.attention_token_indices, device=hidden.device, dtype=torch.long + ) + local_hidden = ( + flat_hidden.index_select(0, local_index) + .unsqueeze(1) + 
.contiguous() + .detach() + .requires_grad_(True) + ) + local_output_grad = flat_grad.index_select(0, local_index).unsqueeze(1).contiguous() + cp_out, _ = run_gdn_layer( + cp_gdn, + local_hidden, + group_ids=group_ids, + parent_ids=parent_ids, + execution_spec=spec, + execution_plan=plan, + cp_group=torch.distributed.group.WORLD, + ) + cp_loss = (cp_out * local_output_grad).sum() + cp_loss.backward() + _assert_cp_matches_reference( + name, + ref_gdn, + cp_gdn, + ref_hidden, + ref_out, + ref_loss.detach(), + local_hidden, + cp_out, + cp_loss.detach(), + local_index, + ) + + def _assert_sibling_order_matches_cp1( ref_gdn: torch.nn.Module, cp_gdn: torch.nn.Module, @@ -377,6 +561,126 @@ def _hidden_and_grad( return hidden, grad +def _tree_hidden_and_grad( + sequence_length: int, *, seed: int +) -> tuple[torch.Tensor, torch.Tensor]: + generator = torch.Generator(device="cuda").manual_seed(seed) + hidden = torch.randn( + sequence_length, + 1, + 64, + device="cuda", + dtype=GDN_CORRECTNESS_DTYPE, + generator=generator, + ) + grad = torch.randn( + hidden.shape, + device="cuda", + dtype=GDN_CORRECTNESS_DTYPE, + generator=generator, + ) + torch.distributed.broadcast(hidden, src=0) + torch.distributed.broadcast(grad, src=0) + return hidden, grad + + +def _tree_chain_pack(): + long_root = torch.arange(11, 267) + short_root = torch.arange(1001, 1097) + long_mid = torch.arange(2001, 2641) + other_mid = torch.arange(3001, 3065) + return pack_shared_prefixes( + ( + torch.cat((long_root, torch.tensor([301]))), + torch.cat((long_root, torch.tensor([302]))), + torch.cat((short_root, long_mid, torch.tensor([401]))), + torch.cat((short_root, long_mid, torch.tensor([402]))), + torch.cat((short_root, other_mid, torch.tensor([403]))), + ), + max_depth=2, + ) + + +def _tree_chain_planner_config() -> GdnPlannerConfig: + return GdnPlannerConfig( + cp_chain_min_tokens_per_rank=16, + cp_chain_min_total_tokens=128, + cp_chain_min_prefix_only_tokens=128, + max_padding_ratio=4.0, + ) + + +def 
_tree_fuzz_planner_config() -> GdnPlannerConfig: + return GdnPlannerConfig( + cp_chain_min_tokens_per_rank=1, + cp_chain_min_total_tokens=64, + cp_chain_min_prefix_only_tokens=64, + cp_tree_chain_min_total_tokens=64, + cp_tree_chain_min_prefix_only_tokens=64, + max_padding_ratio=4.0, + ) + + +def _tree_fuzz_packs() -> tuple[tuple[str, Any], ...]: + return ( + ( + "tree_fuzz_duplicates", + pack_shared_prefixes(_duplicate_tree_sequences(), max_depth=4), + ), + ( + "tree_fuzz_ragged_depth4", + pack_shared_prefixes(_random_tree_sequences(13, max_depth=4), max_depth=4), + ), + ( + "tree_fuzz_mixed_tiny_long", + pack_shared_prefixes(_random_tree_sequences(29, max_depth=5), max_depth=5), + ), + ) + + +def _duplicate_tree_sequences() -> tuple[torch.Tensor, ...]: + root = torch.arange(11, 331) + mid_a = torch.arange(1001, 1261) + mid_b = torch.arange(2001, 2065) + leaf_a = torch.arange(3001, 3013) + leaf_b = torch.arange(4001, 4017) + first = torch.cat((root, mid_a, leaf_a)) + second = torch.cat((root, mid_a, leaf_b)) + third = torch.cat((root, mid_b, torch.tensor([91, 92, 93]))) + return (first, first, second, third, third) + + +def _random_tree_sequences(seed: int, *, max_depth: int) -> tuple[torch.Tensor, ...]: + generator = torch.Generator().manual_seed(seed) + next_token = 1 + + def randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=generator).item()) + + def tokens(length: int) -> torch.Tensor: + nonlocal next_token + out = torch.arange(next_token, next_token + length) + next_token += length + 997 + return out + + def segment_length(depth: int) -> int: + choices = (1, 3, 17, 64, 129, 257, 384 if depth == 0 else 96) + return choices[randint(0, len(choices) - 1)] + + def walk(prefix: torch.Tensor, depth: int) -> list[torch.Tensor]: + here = torch.cat((prefix, tokens(segment_length(depth)))) + if depth + 1 >= max_depth: + return [ + torch.cat((here, tokens(randint(1, 17)))) for _ in range(randint(2, 4)) + ] + leaves: 
list[torch.Tensor] = [] + for _ in range(randint(2, 3)): + leaves.extend(walk(here, depth + 1)) + return leaves + + return tuple(walk(torch.empty(0, dtype=torch.long), 0)) + + def _packed_correctness_cases() -> tuple[GdnPhase0Case, ...]: return ( *default_phase0_cases(conv_width=2), diff --git a/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py b/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py index 19f33970c..fe1159a65 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py @@ -99,63 +99,65 @@ def test_qwen35_full_model_cp1_matches_flattened_grad_accumulation() -> None: spec = parse_gdn_shared_prefix_segments( group_ids.cpu(), parent_ids.cpu(), min_completions_per_family=1 ) - for family in spec.families: - row = family.row_index - prefix = family.prefix - for completion in family.completions: - ref_tokens = torch.cat( - [ - tokens[row : row + 1, prefix.start : prefix.end], - tokens[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_pos = torch.cat( - [ - input_pos[row : row + 1, prefix.start : prefix.end], - input_pos[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_assistant_mask = torch.cat( - [ - torch.zeros( - (1, prefix.length), dtype=torch.bool, device=device - ), - assistant_mask[ - row : row + 1, completion.start : completion.end - ], - ], - dim=1, - ) - ref_group_ids = torch.zeros_like(ref_tokens) - ref_parent_ids = torch.zeros_like(ref_tokens) - ref_logits, ref_loss = _run_model_loss( - flat_model, - tokens=ref_tokens, - input_pos=ref_pos, - group_ids=ref_group_ids, - parent_ids=ref_parent_ids, - assistant_mask=ref_assistant_mask, - ) - ref_loss.backward() - flat_loss_sum = ( - ref_loss.detach() - if flat_loss_sum is None - else flat_loss_sum + ref_loss.detach() - 
) + for segment_index, completion in enumerate(spec.tree_segments): + if spec.tree_parent_indices[segment_index] < 0: + continue + row = completion.row_index + path = _segment_path(spec, segment_index) + completion_offset = sum(segment.length for segment in path[:-1]) + ref_tokens = torch.cat( + [ + tokens[row : row + 1, segment.start : segment.end] + for segment in path + ], + dim=1, + ) + ref_pos = torch.cat( + [ + input_pos[row : row + 1, segment.start : segment.end] + for segment in path + ], + dim=1, + ) + ref_assistant_mask = torch.cat( + [ + torch.zeros( + (1, completion_offset), + dtype=torch.bool, + device=device, + ), + assistant_mask[row : row + 1, completion.start : completion.end], + ], + dim=1, + ) + ref_group_ids = torch.zeros_like(ref_tokens) + ref_parent_ids = torch.zeros_like(ref_tokens) + ref_logits, ref_loss = _run_model_loss( + flat_model, + tokens=ref_tokens, + input_pos=ref_pos, + group_ids=ref_group_ids, + parent_ids=ref_parent_ids, + assistant_mask=ref_assistant_mask, + ) + ref_loss.backward() + flat_loss_sum = ( + ref_loss.detach() + if flat_loss_sum is None + else flat_loss_sum + ref_loss.detach() + ) - if completion.length > 1: - packed_slice = packed_logits[ - row : row + 1, completion.start : completion.end - 1 - ] - ref_slice = ref_logits[ - :, prefix.length : prefix.length + completion.length - 1 - ] - logits_mean_abs_pct = max( - logits_mean_abs_pct, - mean_abs_pct(ref_slice, packed_slice), - ) + if completion.length > 1: + packed_slice = packed_logits[ + row : row + 1, completion.start : completion.end - 1 + ] + ref_slice = ref_logits[ + :, completion_offset : completion_offset + completion.length - 1 + ] + logits_mean_abs_pct = max( + logits_mean_abs_pct, + mean_abs_pct(ref_slice, packed_slice), + ) assert flat_loss_sum is not None grad_name, grad_pct = parameter_grad_mean_abs_pct_with_name( @@ -217,67 +219,63 @@ def _assert_logits_vjp_equivalence( spec = parse_gdn_shared_prefix_segments( group_ids.cpu(), parent_ids.cpu(), 
min_completions_per_family=1 ) - for family in spec.families: - row = family.row_index - prefix = family.prefix - for completion in family.completions: - ref_tokens = torch.cat( - [ - tokens[row : row + 1, prefix.start : prefix.end], - tokens[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_pos = torch.cat( - [ - input_pos[row : row + 1, prefix.start : prefix.end], - input_pos[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_logits = _run_model_logits( - flat_model, - tokens=ref_tokens, - input_pos=ref_pos, - group_ids=torch.zeros_like(ref_tokens), - parent_ids=torch.zeros_like(ref_tokens), - ) - ref_output_grad = torch.zeros_like(ref_logits) - ref_output_mask = torch.zeros( - ref_logits.shape[:2], - device=ref_logits.device, - dtype=torch.bool, - ) - if completion.length > 1: - ref_output_grad[ - :, prefix.length : prefix.length + completion.length - 1 - ] = output_grad[row : row + 1, completion.start : completion.end - 1] - ref_output_mask[ - :, prefix.length : prefix.length + completion.length - 1 - ] = True - ref_loss = stable_output_mse_loss( - ref_logits, - ref_output_grad, - mask=ref_output_mask.unsqueeze(-1), - denominator=loss_denominator, - ) - ref_loss.backward() - flat_loss_sum = ( - ref_loss.detach() - if flat_loss_sum is None - else flat_loss_sum + ref_loss.detach() + for segment_index, completion in enumerate(spec.tree_segments): + if spec.tree_parent_indices[segment_index] < 0: + continue + row = completion.row_index + path = _segment_path(spec, segment_index) + completion_offset = sum(segment.length for segment in path[:-1]) + ref_tokens = torch.cat( + [tokens[row : row + 1, segment.start : segment.end] for segment in path], + dim=1, + ) + ref_pos = torch.cat( + [input_pos[row : row + 1, segment.start : segment.end] for segment in path], + dim=1, + ) + ref_logits = _run_model_logits( + flat_model, + tokens=ref_tokens, + input_pos=ref_pos, + group_ids=torch.zeros_like(ref_tokens), + 
parent_ids=torch.zeros_like(ref_tokens), + ) + ref_output_grad = torch.zeros_like(ref_logits) + ref_output_mask = torch.zeros( + ref_logits.shape[:2], + device=ref_logits.device, + dtype=torch.bool, + ) + if completion.length > 1: + ref_output_grad[ + :, completion_offset : completion_offset + completion.length - 1 + ] = output_grad[row : row + 1, completion.start : completion.end - 1] + ref_output_mask[ + :, completion_offset : completion_offset + completion.length - 1 + ] = True + ref_loss = stable_output_mse_loss( + ref_logits, + ref_output_grad, + mask=ref_output_mask.unsqueeze(-1), + denominator=loss_denominator, + ) + ref_loss.backward() + flat_loss_sum = ( + ref_loss.detach() + if flat_loss_sum is None + else flat_loss_sum + ref_loss.detach() + ) + if completion.length > 1: + packed_slice = packed_logits[ + row : row + 1, completion.start : completion.end - 1 + ] + ref_slice = ref_logits[ + :, completion_offset : completion_offset + completion.length - 1 + ] + logits_mean_abs_pct = max( + logits_mean_abs_pct, + mean_abs_pct(ref_slice, packed_slice), ) - if completion.length > 1: - packed_slice = packed_logits[ - row : row + 1, completion.start : completion.end - 1 - ] - ref_slice = ref_logits[ - :, prefix.length : prefix.length + completion.length - 1 - ] - logits_mean_abs_pct = max( - logits_mean_abs_pct, - mean_abs_pct(ref_slice, packed_slice), - ) assert flat_loss_sum is not None grad_name, grad_pct = parameter_grad_mean_abs_pct_with_name( @@ -359,6 +357,15 @@ def _run_model_logits( return logits +def _segment_path(spec: Any, segment_index: int) -> tuple[Any, ...]: + indices = [] + cursor = segment_index + while cursor >= 0: + indices.append(cursor) + cursor = spec.tree_parent_indices[cursor] + return tuple(spec.tree_segments[index] for index in reversed(indices)) + + def _make_matching_models() -> tuple[torch.nn.Module, torch.nn.Module]: model_parallel_cuda_manual_seed(1234) packed = _make_model() diff --git 
a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py index e0d164c56..2148e3053 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py @@ -139,17 +139,9 @@ def _native_gdn_cp_packed_layer_worker( cp_chain_min_tokens_per_rank=16, cp_chain_min_total_tokens=128, cp_chain_min_prefix_only_tokens=128, - # This test is the native chain correctness guard, so force the - # planner onto chain prefix and completion buckets. - planner_chain_bucket_ms=0.0, - planner_chain_token_ms=0.0, - planner_local_bucket_ms=1.0, - planner_local_token_ms=1.0, - cp_chain_min_score_delta_ms=0.0, ), ) - assert plan.chain_prefix_buckets - assert plan.chain_completion_buckets + assert any(plan.tree_chain_buckets_by_depth) hidden, output_grad = _packed_hidden_and_grad(case, cp_size) ref_hidden = hidden.clone().detach().requires_grad_(True) ref_out, _ = run_gdn_layer( diff --git a/tests/integration/megatron/lora/test_dynamic_lora_slots.py b/tests/integration/megatron/lora/test_dynamic_lora_slots.py new file mode 100644 index 000000000..49a7f8224 --- /dev/null +++ b/tests/integration/megatron/lora/test_dynamic_lora_slots.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from contextlib import contextmanager +import os +import socket +from types import SimpleNamespace + +import pytest + +torch = pytest.importorskip("torch") +pytest.importorskip("megatron.core") + +from megatron.core import parallel_state as ps # noqa: E402 +from torch.distributed import destroy_process_group, init_process_group # noqa: E402 + +from art.megatron.lora import LoRA, LoRASlotRef, use_lora_slot # noqa: E402 +from art.megatron.trainer_rank import AdamParams, TrainerRank # noqa: E402 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +def 
test_dynamic_lora_slots_capture_recompute_context_and_step_independently() -> None: + with _single_rank_model_parallel(): + device = torch.device("cuda") + lora = LoRA( + "dense", + in_features=4, + out_features=5, + rank=2, + alpha=32, + dtype=torch.float32, + device=device, + ) + ref_a = LoRASlotRef("checkpoint", "A") + ref_b = LoRASlotRef("checkpoint", "B") + lora.load_lora_slot( + ref_a, _adapter("dense", rank=1, seed=1), requires_grad=True + ) + lora.load_lora_slot( + ref_b, _adapter("dense", rank=4, seed=2), requires_grad=True + ) + + x = torch.randn(7, 4, device=device) + with use_lora_slot(LoRASlotRef("checkpoint", None)): + assert torch.equal(lora(x), torch.zeros(7, 5, device=device)) + with use_lora_slot(LoRASlotRef("lora", "missing")): + assert torch.equal(lora(x), torch.zeros(7, 5, device=device)) + + slot_a = lora._slot(ref_a) + assert slot_a is not None + with use_lora_slot(ref_a): + actual = lora(x) + expected = (x @ slot_a.A_T) @ slot_a.B_T * slot_a.scale + assert torch.allclose(actual, expected, atol=0, rtol=0) + assert slot_a.rank == 1 + assert slot_a.scale == 32.0 + assert lora._slot(ref_b).scale == 8.0 # type: ignore[union-attr] + + trainer = _trainer_for(lora, device) + with trainer.push_checkpoint("A"): + assert trainer._slot_stack[-1] == ref_a + with trainer.push_lora(None): + assert trainer._slot_stack[-1].name is None + assert trainer._slot_stack[-1] == ref_a + assert trainer._slot_stack == [] + + from megatron.core.tensor_parallel.random import ( + checkpoint as megatron_checkpoint, + ) + from torch.utils.checkpoint import checkpoint as torch_checkpoint + + _assert_checkpoint_recomputes_with(ref_a, ref_b, lora, torch_checkpoint) + _assert_checkpoint_recomputes_with( + ref_a, ref_b, lora, megatron_checkpoint, False + ) + _assert_step_updates_only(ref_a, ref_b, lora, trainer) + + +def _adapter(prefix: str, *, rank: int, seed: int) -> dict[str, torch.Tensor]: + device = torch.device("cuda") + generator = 
torch.Generator(device=device).manual_seed(seed) + return { + f"{prefix}.lora_A.weight": torch.randn( + rank, 4, generator=generator, device=device + ), + f"{prefix}.lora_B.weight": torch.randn( + 5, rank, generator=generator, device=device + ), + } + + +def _assert_checkpoint_recomputes_with( + expected_ref: LoRASlotRef, + ambient_ref: LoRASlotRef, + lora: LoRA, + checkpoint, + *checkpoint_args, +) -> None: + for param in lora.parameters(): + param.grad = None + x = torch.randn(3, 4, device="cuda", requires_grad=True) + with use_lora_slot(expected_ref): + y = checkpoint(lambda t: lora(t), *checkpoint_args, x) + with use_lora_slot(ambient_ref): + y.sum().backward() + assert lora._slot(expected_ref).A_T.grad is not None # type: ignore[union-attr] + assert lora._slot(ambient_ref).A_T.grad is None # type: ignore[union-attr] + + +def _assert_step_updates_only( + stepped_ref: LoRASlotRef, + frozen_ref: LoRASlotRef, + lora: LoRA, + trainer: TrainerRank, +) -> None: + for param in lora.parameters(): + param.grad = None + with use_lora_slot(stepped_ref): + lora(torch.randn(5, 4, device="cuda")).sum().backward() + before_stepped = [p.detach().clone() for p in lora.lora_slot_params(stepped_ref)] + before_frozen = [p.detach().clone() for p in lora.lora_slot_params(frozen_ref)] + trainer.optim_step( + params=AdamParams(learning_rate=1e-3, weight_decay=0.0, grad_clip_norm=1.0), + checkpoints=[stepped_ref.name or ""], + ) + assert any( + not torch.equal(before, after) + for before, after in zip( + before_stepped, lora.lora_slot_params(stepped_ref), strict=True + ) + ) + assert all( + torch.equal(before, after) + for before, after in zip( + before_frozen, lora.lora_slot_params(frozen_ref), strict=True + ) + ) + + +def _trainer_for(lora: LoRA, device: torch.device) -> TrainerRank: + trainer = TrainerRank.__new__(TrainerRank) + trainer.runtime = SimpleNamespace(model=[lora], optimizer=None) + trainer.device = device + trainer._slot_stack = [] + trainer._default_slot_ref = None + 
trainer._dynamic_optimizers = {} + trainer._checkpoint_slot_names = {"A", "B"} + return trainer + + +@contextmanager +def _single_rank_model_parallel(): + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ["MASTER_PORT"] = str(_free_port()) + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + torch.cuda.set_device(0) + init_process_group("nccl", rank=0, world_size=1) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=1, + ) + yield + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + destroy_process_group() + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/lora/test_lora_disk_codecs.py b/tests/integration/megatron/lora/test_lora_disk_codecs.py index b14cd2a4c..7bb3e1b94 100644 --- a/tests/integration/megatron/lora/test_lora_disk_codecs.py +++ b/tests/integration/megatron/lora/test_lora_disk_codecs.py @@ -1,12 +1,17 @@ import json +import os from pathlib import Path +import shutil import subprocess import sys from typing import Any, cast +import pytest from safetensors.torch import load_file, save_file import torch +pytest.importorskip("megatron.bridge.models.gpt_provider") + from art.megatron import lora as lora_module from art.megatron.lora import LoRA, LoRAParallelSpec, LoRAPublishPlanner from art.megatron.model_support.handlers import ( @@ -29,6 +34,66 @@ REPO_ROOT = Path(__file__).parents[4] VLLM_PYTHON = REPO_ROOT / "vllm_runtime/.venv/bin/python" +_VLLM_RUNTIME_UNAVAILABLE_REASON: str | None | object = object() + + +def _vllm_python_cmd() -> list[str]: + override = os.environ.get("ART_TEST_VLLM_PYTHON") + if override: + return [override] + if 
VLLM_PYTHON.exists(): + return [str(VLLM_PYTHON)] + uv = shutil.which("uv") + if uv is None: + raise RuntimeError( + f"{VLLM_PYTHON} does not exist and uv is not available to run " + "the locked vLLM runtime project" + ) + return [ + uv, + "run", + "--project", + str(REPO_ROOT / "vllm_runtime"), + "--frozen", + "--no-dev", + "python", + ] + + +def _vllm_runtime_unavailable_reason() -> str | None: + global _VLLM_RUNTIME_UNAVAILABLE_REASON + if isinstance(_VLLM_RUNTIME_UNAVAILABLE_REASON, str): + return _VLLM_RUNTIME_UNAVAILABLE_REASON + if _VLLM_RUNTIME_UNAVAILABLE_REASON is None: + return None + try: + subprocess.run( + [ + *_vllm_python_cmd(), + "-c", + "import vllm; from vllm.lora.lora_model import LoRAModel", + ], + check=True, + text=True, + capture_output=True, + timeout=120, + ) + except Exception as exc: + _VLLM_RUNTIME_UNAVAILABLE_REASON = ( + "Stock vLLM loader runtime is unavailable. Run " + "`uv sync --project vllm_runtime --frozen --no-dev`, or set " + "`ART_TEST_VLLM_PYTHON` to a Python environment with vLLM installed. 
" + f"Original error: {exc}" + ) + return _VLLM_RUNTIME_UNAVAILABLE_REASON + _VLLM_RUNTIME_UNAVAILABLE_REASON = None + return None + + +def test_stock_vllm_loader_runtime_is_available() -> None: + reason = _vllm_runtime_unavailable_reason() + if reason is not None: + pytest.fail(reason) def _config(base_model: str, rank: int = 2, alpha: int = 4) -> dict: @@ -116,6 +181,8 @@ def _assert_stock_vllm_loads( expected_modules: set[str], mapper: str = "none", ) -> list[str]: + if reason := _vllm_runtime_unavailable_reason(): + pytest.skip(reason) script = r""" import json import sys @@ -142,7 +209,7 @@ def _assert_stock_vllm_loads( """ result = subprocess.run( [ - str(VLLM_PYTHON), + *_vllm_python_cmd(), "-c", script, str(path), diff --git a/tests/unit/test_shared_prefix_attention_builder.py b/tests/unit/test_shared_prefix_attention_builder.py new file mode 100644 index 000000000..1214d344e --- /dev/null +++ b/tests/unit/test_shared_prefix_attention_builder.py @@ -0,0 +1,566 @@ +from __future__ import annotations + +import pytest +import torch +from torch.nn.attention.flex_attention import BlockMask +from torch.nn.attention.flex_attention import create_block_mask as torch_block_mask + +pytest.importorskip("megatron.core.packed_seq_params") + +from art.megatron.context_parallel.block_mask import build_block_mask +from art.megatron.context_parallel.builder import ( + build_dense_reference_mask, + build_shared_prefix_attention_spec, +) +from art.megatron.context_parallel.runtime import ( + build_context_parallel_token_layout_index, + get_or_build_runtime_plan, + make_runtime_key, +) +from art.megatron.context_parallel.types import ( + AttnMaskKind, + AttnSlice, + ContextParallelConfig, + ExactMaskMetadata, + FlexMaskSpec, + ParallelTopology, + TokenRange, +) +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes +from art.megatron.shared_prefix_state import create_shared_prefix_state + + +def 
test_shared_prefix_attention_spec_supports_branching_completions() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + dense = build_dense_reference_mask(row_spec=spec.rows[0]) + + assert dense.int().tolist() == [ + [1, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 1, 0], + [1, 1, 1, 0, 0, 1, 1], + ] + + +def test_shared_prefix_attention_spec_matches_tree_reference() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + dense = build_dense_reference_mask(row_spec=spec.rows[0]) + + assert dense.equal(_reference_tree_mask(group_ids[0], parent_ids[0])) + + +def test_shared_prefix_can_build_context_parallel_layout() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + + layout = build_context_parallel_token_layout_index( + group_ids=group_ids, + parent_ids=parent_ids, + topology=ParallelTopology(cp=2), + config=ContextParallelConfig(planner_chunk_size=2, planner_max_search_steps=1), + original_seq_len=int(group_ids.numel()), + ) + + assert sum(layout.token_counts_by_rank) == int(group_ids.numel()) + + +def test_sparse_block_mask_exact_predicate_matches_dense_reference() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + row = spec.rows[0] + token_indices = torch.arange(row.valid_tokens, dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=row.valid_tokens, + k_len=row.valid_tokens, + block_size=(2, 2), + slices=row.slices, + exact_mask=ExactMaskMetadata( + q_token_indices=token_indices, + k_token_indices=token_indices, + cache_key="depth-two", + ), + ), + group_ids=group_ids[0], + parent_ids=parent_ids[0], + 
device=torch.device("cpu"), + ) + + assert block_mask is not None + q_indices = torch.arange(row.valid_tokens)[:, None] + k_indices = torch.arange(row.valid_tokens)[None, :] + actual = block_mask.mask_mod( + torch.zeros_like(q_indices), + torch.zeros_like(q_indices), + q_indices, + k_indices, + ) + + assert actual.equal(build_dense_reference_mask(row_spec=row)) + + +@pytest.mark.parametrize( + ("name", "pack"), + ( + ( + "no-sharing", + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3]), + torch.tensor([4, 5]), + torch.tensor([6, 7, 8, 9]), + ), + max_depth=0, + ), + ), + ( + "depth-one", + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 2, 6]), + ), + max_depth=1, + ), + ), + ( + "depth-three", + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 3, 4, 9]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=3, + ), + ), + ), +) +def test_sparse_block_mask_matches_torch_block_metadata( + name: str, + pack: SharedPrefixPack, +) -> None: + del name + spec = build_shared_prefix_attention_spec( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + row = spec.rows[0] + token_indices = torch.arange(row.valid_tokens, dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=row.valid_tokens, + k_len=row.valid_tokens, + block_size=(2, 2), + slices=row.slices, + exact_mask=ExactMaskMetadata( + q_token_indices=token_indices, + k_token_indices=token_indices, + cache_key="torch-parity", + ), + ), + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + device=torch.device("cpu"), + ) + + assert block_mask is not None + _assert_matches_torch_block_mask(block_mask) + + +def test_sparse_block_mask_prunes_exact_blocks_rejected_by_group_tree() -> None: + group_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.long) + parent_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( 
+ q_len=4, + k_len=4, + block_size=(2, 2), + slices=( + AttnSlice( + q_range=TokenRange(start=0, end=4), + k_range=TokenRange(start=0, end=4), + mask_kind=AttnMaskKind.CAUSAL, + row_index=0, + ), + ), + exact_mask=ExactMaskMetadata( + q_token_indices=torch.tensor([4, 5, 6, 7], dtype=torch.long), + k_token_indices=torch.tensor([0, 1, 2, 3], dtype=torch.long), + cache_key="all-false-cross-family", + ), + ), + group_ids=group_ids, + parent_ids=parent_ids, + device=torch.device("cpu"), + ) + + assert block_mask is not None + assert int(block_mask.kv_num_blocks.sum().item()) == 0 + assert int(block_mask.full_kv_num_blocks.sum().item()) == 0 + _assert_matches_torch_block_mask(block_mask) + + +def test_shared_prefix_state_builds_batched_block_mask() -> None: + group_ids = torch.tensor( + [ + [1, 1, 2, 2, -1], + [10, 11, 11, -1, -1], + ], + dtype=torch.long, + ) + parent_ids = torch.tensor( + [ + [1, 1, 1, 1, -1], + [10, 10, 10, -1, -1], + ], + dtype=torch.long, + ) + + state = create_shared_prefix_state( + group_ids=group_ids, + parent_ids=parent_ids, + target_device=torch.device("cpu"), + ) + seq_len = int(group_ids.shape[1]) + batch_idx = torch.arange(2)[:, None, None].expand(2, seq_len, seq_len) + query_idx = torch.arange(seq_len)[None, :, None].expand(2, seq_len, seq_len) + kv_idx = torch.arange(seq_len)[None, None, :].expand(2, seq_len, seq_len) + actual = state.block_mask.mask_mod( + batch_idx, + torch.zeros_like(batch_idx), + query_idx, + kv_idx, + ) + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + assert int(state.block_mask.kv_num_blocks.shape[0]) == 2 + for row_index, row_spec in enumerate(spec.rows): + valid_tokens = int(row_spec.valid_tokens) + assert actual[ + row_index, + :valid_tokens, + :valid_tokens, + ].equal(build_dense_reference_mask(row_spec=row_spec)) + _assert_matches_torch_block_mask(state.block_mask, batch_size=2) + + +def test_context_parallel_stage_masks_match_dense_nested_tree() -> None: + 
_assert_context_parallel_stage_masks_match_dense( + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 3, 4, 9]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=3, + ), + require_remote_stage=True, + ) + _assert_context_parallel_stage_masks_match_dense( + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3]), + torch.tensor([4, 5, 6]), + torch.tensor([7, 8]), + torch.tensor([9, 10, 11, 12]), + ), + max_depth=3, + ), + require_remote_stage=False, + ) + + +def _assert_context_parallel_stage_masks_match_dense( + pack: SharedPrefixPack, + *, + require_remote_stage: bool, +) -> None: + spec = build_shared_prefix_attention_spec( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + row = spec.rows[0] + dense = build_dense_reference_mask(row_spec=row) + topology = ParallelTopology(cp=2) + config = ContextParallelConfig( + block_size=2, + planner_chunk_size=2, + planner_max_search_steps=1, + planner_remote_stage_token_floor=1, + planner_remote_stage_pair_floor=1, + ) + plan = get_or_build_runtime_plan( + spec, + topology=topology, + config=config, + runtime_key=make_runtime_key(spec, topology=topology, config=config), + original_seq_len=int(pack.tokens.numel()), + ) + + checked_stages = 0 + checked_remote_stages = 0 + for rank_plan in plan.rank_plans: + for stage in rank_plan.stage_plans: + if stage.mask_metadata is None: + continue + block_mask = build_block_mask( + FlexMaskSpec( + q_len=stage.q_len, + k_len=stage.k_len, + block_size=(2, 2), + slices=stage.slices, + exact_mask=stage.mask_metadata, + ), + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + device=torch.device("cpu"), + ) + assert block_mask is not None + q_offsets = torch.arange(stage.q_len)[:, None] + k_offsets = torch.arange(stage.k_len)[None, :] + actual = block_mask.mask_mod( + torch.zeros_like(q_offsets), + torch.zeros_like(q_offsets), + q_offsets, + k_offsets, + ) + q_tokens = stage.mask_metadata.q_token_indices + k_tokens = 
stage.mask_metadata.k_token_indices + expected = ( + dense[q_tokens.clamp_min(0)[:, None], k_tokens.clamp_min(0)[None, :]] + & (q_tokens[:, None] >= 0) + & (k_tokens[None, :] >= 0) + ) + + assert actual.equal(expected) + assert _effective_block_mask(block_mask).equal(expected) + _assert_matches_torch_block_mask(block_mask) + checked_stages += 1 + checked_remote_stages += int(not stage.is_local_stage) + + assert checked_stages + if require_remote_stage: + assert checked_remote_stages + + +def _effective_block_mask(block_mask: BlockMask) -> torch.Tensor: + q_len, k_len = block_mask.seq_lengths + q_block, k_block = block_mask.BLOCK_SIZE + effective = torch.zeros((q_len, k_len), dtype=torch.bool) + _fill_full_blocks(effective, block_mask, q_block=q_block, k_block=k_block) + _fill_partial_blocks(effective, block_mask, q_block=q_block, k_block=k_block) + return effective + + +def _fill_full_blocks( + effective: torch.Tensor, + block_mask: BlockMask, + *, + q_block: int, + k_block: int, +) -> None: + if block_mask.full_kv_num_blocks is None or block_mask.full_kv_indices is None: + return + for q_block_index in range(int(block_mask.full_kv_num_blocks.shape[-1])): + q_slice = slice(q_block_index * q_block, (q_block_index + 1) * q_block) + block_count = int(block_mask.full_kv_num_blocks[0, 0, q_block_index]) + for k_block_index in block_mask.full_kv_indices[ + 0, 0, q_block_index, :block_count + ].tolist(): + k_slice = slice( + int(k_block_index) * k_block, + (int(k_block_index) + 1) * k_block, + ) + effective[q_slice, k_slice] = True + + +def _fill_partial_blocks( + effective: torch.Tensor, + block_mask: BlockMask, + *, + q_block: int, + k_block: int, +) -> None: + for q_block_index in range(int(block_mask.kv_num_blocks.shape[-1])): + q_offsets = torch.arange( + q_block_index * q_block, + min((q_block_index + 1) * q_block, effective.shape[0]), + )[:, None] + block_count = int(block_mask.kv_num_blocks[0, 0, q_block_index]) + for k_block_index in block_mask.kv_indices[ + 0, 
0, q_block_index, :block_count + ].tolist(): + k_offsets = torch.arange( + int(k_block_index) * k_block, + min((int(k_block_index) + 1) * k_block, effective.shape[1]), + )[None, :] + effective[q_offsets, k_offsets] |= block_mask.mask_mod( + torch.zeros_like(q_offsets), + torch.zeros_like(q_offsets), + q_offsets, + k_offsets, + ) + + +def test_sparse_block_mask_supports_non_monotonic_remote_k_indices() -> None: + q_token_indices = torch.tensor([4, 5, 6, 7], dtype=torch.long) + k_token_indices = torch.tensor([0, 1, 6, 2, 3, 4], dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=int(q_token_indices.numel()), + k_len=int(k_token_indices.numel()), + block_size=(2, 2), + slices=( + AttnSlice( + q_range=TokenRange(start=0, end=int(q_token_indices.numel())), + k_range=TokenRange(start=0, end=int(k_token_indices.numel())), + mask_kind=AttnMaskKind.CAUSAL, + row_index=0, + ), + ), + exact_mask=ExactMaskMetadata( + q_token_indices=q_token_indices, + k_token_indices=k_token_indices, + cache_key="non-monotonic-k", + ), + ), + group_ids=torch.ones(8, dtype=torch.long), + parent_ids=torch.ones(8, dtype=torch.long), + device=torch.device("cpu"), + ) + + assert block_mask is not None + q_indices = torch.arange(q_token_indices.numel())[:, None] + k_indices = torch.arange(k_token_indices.numel())[None, :] + + actual = block_mask.mask_mod( + torch.zeros_like(q_indices), + torch.zeros_like(q_indices), + q_indices, + k_indices, + ) + + assert actual.equal(q_token_indices[:, None] >= k_token_indices[None, :]) + _assert_matches_torch_block_mask(block_mask) + + +def _assert_matches_torch_block_mask( + block_mask: BlockMask, + *, + batch_size: int = 1, +) -> None: + q_len, k_len = block_mask.seq_lengths + reference = torch_block_mask( + block_mask.mask_mod, + B=batch_size, + H=1, + Q_LEN=q_len, + KV_LEN=k_len, + device="cpu", + BLOCK_SIZE=block_mask.BLOCK_SIZE, + ) + assert _effective_block_mask(block_mask).equal(_effective_block_mask(reference)) + for counts_name, 
indices_name in ( + ("kv_num_blocks", "kv_indices"), + ("full_kv_num_blocks", "full_kv_indices"), + ("q_num_blocks", "q_indices"), + ("full_q_num_blocks", "full_q_indices"), + ): + assert _block_entries(block_mask, counts_name, indices_name) == _block_entries( + reference, + counts_name, + indices_name, + ) + + +def _block_entries( + block_mask: BlockMask, + counts_name: str, + indices_name: str, +) -> set[tuple[int, int, int, int]]: + counts = getattr(block_mask, counts_name) + indices = getattr(block_mask, indices_name) + if counts is None or indices is None: + return set() + entries = set() + for batch_index in range(int(counts.shape[0])): + for head_index in range(int(counts.shape[1])): + for block_index in range(int(counts.shape[2])): + block_count = int(counts[batch_index, head_index, block_index]) + for other_block in indices[ + batch_index, + head_index, + block_index, + :block_count, + ].tolist(): + entries.add( + ( + batch_index, + head_index, + block_index, + int(other_block), + ) + ) + return entries + + +def _branching_prefix_inputs() -> tuple[torch.Tensor, torch.Tensor]: + return ( + torch.tensor([[1, 1, 1, 2, 3, 4, 4]], dtype=torch.long), + torch.tensor([[1, 1, 1, 1, 1, 1, 1]], dtype=torch.long), + ) + + +def _reference_tree_mask( + group_ids: torch.Tensor, parent_ids: torch.Tensor +) -> torch.Tensor: + group_list = [int(value) for value in group_ids.tolist()] + parent_by_group: dict[int, int | None] = {} + for group_id, parent_id in zip(group_list, parent_ids.tolist(), strict=True): + group_id = int(group_id) + parent_id = int(parent_id) + if group_id not in parent_by_group: + parent_by_group[group_id] = None if parent_id == group_id else parent_id + + ancestors_by_group = { + group_id: _ancestors(group_id, parent_by_group) for group_id in parent_by_group + } + dense = torch.zeros((len(group_list), len(group_list)), dtype=torch.bool) + for q_pos, q_group in enumerate(group_list): + allowed_groups = ancestors_by_group[q_group] | {q_group} + for 
k_pos, k_group in enumerate(group_list): + dense[q_pos, k_pos] = k_pos <= q_pos and k_group in allowed_groups + return dense + + +def _ancestors( + group_id: int, + parent_by_group: dict[int, int | None], +) -> set[int]: + ancestors: set[int] = set() + cursor = parent_by_group[group_id] + while cursor is not None and cursor not in ancestors: + ancestors.add(cursor) + cursor = parent_by_group.get(cursor) + return ancestors diff --git a/tests/unit/test_shared_prefix_grad_parity.py b/tests/unit/test_shared_prefix_grad_parity.py new file mode 100644 index 000000000..5b812782b --- /dev/null +++ b/tests/unit/test_shared_prefix_grad_parity.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +from copy import deepcopy + +import pytest +import torch +from torch import nn +import torch.nn.functional as F + +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes + + +class _ToyCausalLM(nn.Module): + def __init__(self) -> None: + super().__init__() + self.token_embedding = nn.Embedding(32, 8, dtype=torch.float64) + self.position_embedding = nn.Embedding(8, 8, dtype=torch.float64) + self.mix = nn.Linear(8, 8, bias=False, dtype=torch.float64) + self.output = nn.Linear(8, 32, bias=False, dtype=torch.float64) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + causal_mask: torch.Tensor, + ) -> torch.Tensor: + states = self.token_embedding(input_ids) + self.position_embedding(position_ids) + context = causal_mask.to(states.dtype) @ states + return self.output(torch.tanh(self.mix(context))) + + +@pytest.mark.parametrize("max_depth", (1, 2, 3)) +@pytest.mark.parametrize("multi_target", (False, True)) +def test_shared_prefix_ce_parameter_grads_match_independent_sequences( + *, + max_depth: int, + multi_target: bool, +) -> None: + input_ids = _input_ids() + target_ids = tuple( + _targets(tokens, multi_target=multi_target) for tokens in input_ids + ) + pack = pack_shared_prefixes(input_ids, max_depth=max_depth) + + 
assert int(pack.tokens.numel()) < sum(len(row) for row in input_ids) + + torch.manual_seed(20260518) + naive_model = _ToyCausalLM() + packed_model = deepcopy(naive_model) + + naive_loss = torch.stack( + [ + _sequence_ce_loss(naive_model, tokens, labels) + for tokens, labels in zip(input_ids, target_ids, strict=True) + ] + ).sum() + packed_loss = _packed_ce_loss(packed_model, pack, target_ids) + + torch.testing.assert_close(packed_loss, naive_loss, rtol=1e-12, atol=1e-12) + naive_loss.backward() + packed_loss.backward() + + for (name, naive_param), packed_param in zip( + naive_model.named_parameters(), + packed_model.parameters(), + strict=True, + ): + assert naive_param.grad is not None, name + assert packed_param.grad is not None, name + torch.testing.assert_close( + packed_param.grad, + naive_param.grad, + rtol=1e-10, + atol=1e-10, + msg=lambda msg, name=name: f"{name} grad mismatch:\n{msg}", + ) + + +@pytest.mark.parametrize("max_depth", (1, 2, 3)) +def test_same_layout_mutation_preserves_forward_outputs(max_depth: int) -> None: + pack = pack_shared_prefixes(_input_ids(), max_depth=max_depth) + torch.manual_seed(20260518) + model = _ToyCausalLM() + logits = _packed_logits(model, pack) + + for positions in pack.positions_by_sequence: + mutated_logits = _packed_logits(model, _mutated_pack(pack, keep=positions)) + torch.testing.assert_close( + mutated_logits.index_select(0, positions), + logits.index_select(0, positions), + rtol=0.0, + atol=0.0, + ) + + +@pytest.mark.parametrize("max_depth", (1, 2, 3)) +@pytest.mark.parametrize("sequence_index", (0, 2, 4)) +def test_same_layout_mutation_preserves_target_loss_grads( + max_depth: int, + sequence_index: int, +) -> None: + input_ids = _input_ids() + target_ids = tuple(_targets(tokens, multi_target=True) for tokens in input_ids) + pack = pack_shared_prefixes(input_ids, max_depth=max_depth) + mutated = _mutated_pack(pack, keep=pack.positions_by_sequence[sequence_index]) + + torch.manual_seed(20260518) + base_model = 
_ToyCausalLM() + mutated_model = deepcopy(base_model) + + base_loss = _packed_sequence_ce_loss(base_model, pack, target_ids, sequence_index) + mutated_loss = _packed_sequence_ce_loss( + mutated_model, + mutated, + target_ids, + sequence_index, + ) + + torch.testing.assert_close(mutated_loss, base_loss, rtol=0.0, atol=0.0) + base_loss.backward() + mutated_loss.backward() + _assert_matching_grads(mutated_model, base_model) + + +def _input_ids() -> tuple[torch.Tensor, ...]: + return ( + torch.tensor([1, 2, 3, 4, 5]), + torch.tensor([1, 2, 3, 4, 6]), + torch.tensor([1, 2, 3, 7]), + torch.tensor([1, 2, 8]), + torch.tensor([9, 10, 11]), + ) + + +def _targets(tokens: torch.Tensor, *, multi_target: bool) -> torch.Tensor: + labels = (tokens * 3 + 5) % 31 + if not multi_target: + return labels + alternate = (tokens * 5 + 7) % 31 + stacked = torch.stack((labels, alternate), dim=1) + if int(stacked.numel()) > 2: + stacked[1, 1] = -100 + return stacked + + +def _sequence_ce_loss( + model: _ToyCausalLM, + input_ids: torch.Tensor, + target_ids: torch.Tensor, +) -> torch.Tensor: + seq_len = int(input_ids.numel()) + logits = model( + input_ids, + torch.arange(seq_len), + torch.ones((seq_len, seq_len), dtype=torch.bool).tril(), + ) + return _target_ce_loss(logits, target_ids) + + +def _packed_ce_loss( + model: _ToyCausalLM, + pack: SharedPrefixPack, + target_ids: tuple[torch.Tensor, ...], +) -> torch.Tensor: + logits = _packed_logits(model, pack) + losses = [ + _target_ce_loss(logits.index_select(0, positions), labels) + for positions, labels in zip( + pack.positions_by_sequence, + target_ids, + strict=True, + ) + ] + return torch.stack(losses).sum() + + +def _packed_sequence_ce_loss( + model: _ToyCausalLM, + pack: SharedPrefixPack, + target_ids: tuple[torch.Tensor, ...], + sequence_index: int, +) -> torch.Tensor: + return _target_ce_loss( + _packed_logits(model, pack).index_select( + 0, + pack.positions_by_sequence[sequence_index], + ), + target_ids[sequence_index], + ) + + +def 
_packed_logits(model: _ToyCausalLM, pack: SharedPrefixPack) -> torch.Tensor: + return model( + pack.tokens.reshape(-1), + pack.position_ids.reshape(-1), + _shared_prefix_causal_mask(pack), + ) + + +def _target_ce_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + if labels.ndim == 1: + return F.cross_entropy(logits, labels, ignore_index=-100, reduction="sum") + expanded = logits.unsqueeze(1).expand(-1, int(labels.shape[1]), -1) + return F.cross_entropy( + expanded.reshape(-1, int(logits.shape[-1])), + labels.reshape(-1), + ignore_index=-100, + reduction="sum", + ) + + +def _mutated_pack(pack: SharedPrefixPack, *, keep: torch.Tensor) -> SharedPrefixPack: + tokens = pack.tokens.clone() + mutate = torch.ones(int(tokens.shape[1]), dtype=torch.bool) + mutate[keep] = False + replacement = torch.arange(int(tokens.shape[1]), dtype=tokens.dtype) + 17 + tokens[0, mutate] = replacement[mutate] % 31 + return SharedPrefixPack( + tokens=tokens, + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + position_ids=pack.position_ids, + positions_by_sequence=pack.positions_by_sequence, + ) + + +def _assert_matching_grads(actual_model: nn.Module, expected_model: nn.Module) -> None: + for (name, expected_param), actual_param in zip( + expected_model.named_parameters(), + actual_model.parameters(), + strict=True, + ): + assert expected_param.grad is not None, name + assert actual_param.grad is not None, name + torch.testing.assert_close( + actual_param.grad, + expected_param.grad, + rtol=1e-10, + atol=1e-10, + msg=lambda msg, name=name: f"{name} grad mismatch:\n{msg}", + ) + + +def _shared_prefix_causal_mask(pack: SharedPrefixPack) -> torch.Tensor: + group_ids = pack.group_ids.reshape(-1).tolist() + parent_ids = pack.parent_ids.reshape(-1).tolist() + position_ids = pack.position_ids.reshape(-1).tolist() + parent_by_group: dict[int, int] = {} + for group_id, parent_id in zip(group_ids, parent_ids, strict=True): + previous = parent_by_group.setdefault(group_id, 
parent_id) + assert previous == parent_id + + ancestors = { + group_id: _ancestor_groups(group_id, parent_by_group) + for group_id in parent_by_group + } + mask = torch.zeros((len(group_ids), len(group_ids)), dtype=torch.bool) + for query_index, query_group in enumerate(group_ids): + query_ancestors = ancestors[query_group] + query_position = position_ids[query_index] + for key_index, key_group in enumerate(group_ids): + if ( + key_group in query_ancestors + and position_ids[key_index] <= query_position + ): + mask[query_index, key_index] = True + return mask + + +def _ancestor_groups(group_id: int, parent_by_group: dict[int, int]) -> set[int]: + ancestors = {group_id} + parent_id = parent_by_group[group_id] + while parent_id != group_id: + if parent_id in ancestors: + raise AssertionError("shared-prefix group parents contain a cycle") + ancestors.add(parent_id) + group_id = parent_id + parent_id = parent_by_group[group_id] + return ancestors diff --git a/tests/unit/test_shared_prefix_packing.py b/tests/unit/test_shared_prefix_packing.py new file mode 100644 index 000000000..d1c17d7d8 --- /dev/null +++ b/tests/unit/test_shared_prefix_packing.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import pytest +import torch + +from art.megatron.shared_prefix_packing import ( + pack_shared_prefixes, + visualize_shared_prefix_pack, +) +from art.megatron.trainer_rank import _local_position_pairs + + +def test_pack_shared_prefixes_support_depth_one() -> None: + inputs = ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 5]), + torch.tensor([9]), + ) + + pack = pack_shared_prefixes(inputs, max_depth=1) + + assert pack.tokens.tolist() == [[1, 2, 3, 4, 5, 9]] + assert pack.group_ids.tolist() == [[1, 1, 2, 2, 3, 4]] + assert pack.parent_ids.tolist() == [[1, 1, 1, 1, 1, 4]] + assert pack.position_ids.tolist() == [[0, 1, 2, 3, 2, 0]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0, 1, 2, 3], + [0, 1, 4], + [5], + ] + + +def 
test_pack_shared_prefixes_support_zero_depth_without_sharing() -> None: + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2]), + torch.tensor([1, 3]), + torch.tensor([4]), + ), + max_depth=0, + ) + + assert pack.tokens.tolist() == [[1, 2, 1, 3, 4]] + assert pack.group_ids.tolist() == [[1, 1, 2, 2, 3]] + assert pack.parent_ids.tolist() == [[1, 1, 2, 2, 3]] + assert pack.position_ids.tolist() == [[0, 1, 0, 1, 0]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0, 1], + [2, 3], + [4], + ] + + +def test_pack_shared_prefixes_support_deeper_trees() -> None: + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6, 7]), + ), + max_depth=2, + ) + + assert pack.tokens.tolist() == [[1, 2, 3, 4, 5, 6, 7]] + assert pack.group_ids.tolist() == [[1, 2, 2, 3, 4, 5, 5]] + assert pack.parent_ids.tolist() == [[1, 1, 1, 2, 2, 1, 1]] + assert pack.position_ids.tolist() == [[0, 1, 2, 3, 3, 1, 2]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0, 1, 2, 3], + [0, 1, 2, 4], + [0, 5, 6], + ] + + +def test_packing_preserves_first_seen_branch_order() -> None: + pack = pack_shared_prefixes( + (torch.tensor([9]), torch.tensor([1])), + max_depth=1, + ) + + assert pack.tokens.tolist() == [[9, 1]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0], + [1], + ] + + +def test_packing_handles_empty_sequences() -> None: + pack = pack_shared_prefixes( + (torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)), + max_depth=1, + ) + + assert pack.tokens.tolist() == [[]] + assert pack.group_ids.tolist() == [[]] + assert pack.parent_ids.tolist() == [[]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [[], []] + + +def test_packing_rejects_non_1d_sequences() -> None: + with pytest.raises(ValueError, match="expects 1-D tensors"): + pack_shared_prefixes((torch.tensor([[1, 2], [3, 4]]),), max_depth=1) + + 
+def test_visualization_includes_reverse_index() -> None: + pack = pack_shared_prefixes( + (torch.tensor([1, 2, 3]), torch.tensor([1, 2, 4])), + max_depth=1, + ) + + visualization = visualize_shared_prefix_pack(pack) + + assert visualization.splitlines()[0] == "pos token group parent source_pos" + assert "seq 1: [0, 1, 3]" in visualization + + +def test_local_position_pairs_preserve_requested_order_without_dense_match() -> None: + local_global_positions = torch.tensor([[2, -1, 0, 4, 1]]) + item_positions = torch.tensor([0, 1, 2, 3, 4]) + + local_positions, source_positions = _local_position_pairs( + local_global_positions, + item_positions, + ) + + assert local_positions.tolist() == [2, 4, 0, 3] + assert source_positions.tolist() == [0, 1, 2, 4] diff --git a/tests/unit/test_shared_prefix_tree.py b/tests/unit/test_shared_prefix_tree.py new file mode 100644 index 000000000..57cc9fa5c --- /dev/null +++ b/tests/unit/test_shared_prefix_tree.py @@ -0,0 +1,524 @@ +from __future__ import annotations + +import pytest +import torch + +from art.megatron.shared_prefix_packing import pack_shared_prefixes +from art.megatron.shared_prefix_tree import ( + max_shared_prefix_tree_depth, + parse_shared_prefix_row, +) + + +def test_parse_shared_prefix_row_tracks_ancestors_and_depth() -> None: + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 3, 4, 9]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=3, + ) + + tree = parse_shared_prefix_row( + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + ) + + assert tree.valid_tokens == int(pack.tokens.numel()) + assert tree.max_depth == 3 + assert [(segment.group_id, segment.ancestors) for segment in tree.segments] == [ + (1, ()), + (2, (1,)), + (3, (1, 2)), + (4, (1, 2, 3)), + (5, (1, 2, 3)), + (6, (1, 2)), + (7, (1,)), + ] + + +def test_parse_shared_prefix_row_rejects_missing_parent() -> None: + with pytest.raises(RuntimeError, match="missing parent"): + 
parse_shared_prefix_row( + group_ids=torch.tensor([1, 2]), + parent_ids=torch.tensor([1, 3]), + ) + + +def test_parse_shared_prefix_row_rejects_non_contiguous_group() -> None: + with pytest.raises(RuntimeError, match="contiguous group runs"): + parse_shared_prefix_row( + group_ids=torch.tensor([1, 2, 1]), + parent_ids=torch.tensor([1, 1, 1]), + ) + + +def test_max_shared_prefix_tree_depth_treats_flat_families_as_depth_one() -> None: + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 5]), + torch.tensor([9]), + ), + max_depth=1, + ) + + assert ( + max_shared_prefix_tree_depth( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + == 1 + ) + + +def test_gdn_tree_parser_accepts_nested_tree() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=2, + ) + + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plan = build_gdn_rank_execution_plan(spec, device="cpu") + + assert spec.tree_parent_indices == (-1, 0, 1, 1, 0) + assert spec.tree_depths == (0, 1, 2, 2, 1) + assert [ + sum(bucket.segment_count for bucket in buckets) + for buckets in plan.tree_segment_buckets_by_depth + ] == [1, 2, 2] + + +def test_gdn_tree_parser_accepts_zero_depth_roots() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2]), + torch.tensor([1, 3]), + torch.tensor([4]), + ), + max_depth=0, + ) + + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plan = 
build_gdn_rank_execution_plan(spec, device="cpu") + + assert spec.tree_parent_indices == (-1, -1, -1) + assert spec.tree_depths == (0, 0, 0) + assert [bucket.segment_count for bucket in plan.tree_segment_buckets_by_depth[0]] + assert not hasattr(plan, "local_prefix_buckets") + assert not hasattr(plan, "chain_completion_buckets") + assert not hasattr(plan, "prefix_boundary_buckets") + assert all( + not bucket.needs_final_state for bucket in plan.tree_segment_buckets_by_depth[0] + ) + + +def test_gdn_tree_planner_splits_leaf_and_internal_final_state_buckets() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 7]), + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 5, 6]), + ), + max_depth=2, + ) + + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plan = build_gdn_rank_execution_plan( + spec, + device="cpu", + planner_config=GdnPlannerConfig(max_padding_ratio=4.0), + ) + tree_has_children = _tree_has_children(spec) + + depth_one_buckets = plan.tree_segment_buckets_by_depth[1] + assert any(bucket.needs_final_state for bucket in depth_one_buckets) + assert any(not bucket.needs_final_state for bucket in depth_one_buckets) + for bucket in depth_one_buckets: + expected = { + tree_has_children[family_index] + for family_index in bucket.family_indices.tolist() + } + assert expected == {bucket.needs_final_state} + + +def test_gdn_tree_cp_plan_chains_long_nodes() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + root = torch.arange(1, 321) + mid = torch.arange(1001, 1321) + other = torch.arange(2001, 2321) + pack = pack_shared_prefixes( 
+ ( + torch.cat((root, mid, torch.tensor([11]))), + torch.cat((root, mid, torch.tensor([12]))), + torch.cat((root, other, torch.tensor([13]))), + ), + max_depth=3, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + config = _chain_every_legal_segment_config() + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=config, + ) + for rank in range(4) + ) + + assert _covered_token_indices(plans) == set(range(spec.real_token_count)) + assert any(plans[0].tree_chain_buckets_by_depth[0]) + assert not any( + bucket + for plan in plans + for depth_buckets in plan.tree_chain_buckets_by_depth[1:] + for bucket in depth_buckets + ) + _assert_remote_parent_state_transfers_cover(spec, plans) + for plan in plans: + assert sum(plan.gdn_token_count for plan in plans) == spec.real_token_count + for depth_buckets in plan.tree_chain_buckets_by_depth: + for bucket in depth_buckets: + assert bucket.lengths_by_rank_cpu is not None + assert tuple(bucket.lengths_by_rank_cpu.shape)[0] == 4 + assert bucket.parent_indices is not None + + +def test_gdn_tree_cp_plan_exchanges_remote_parent_states() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + root = torch.arange(1, 17) + mid = torch.arange(1001, 1321) + pack = pack_shared_prefixes( + ( + torch.cat((root, mid, torch.tensor([11]))), + torch.cat((root, mid, torch.tensor([12]))), + torch.cat((root, torch.tensor([99]))), + ), + max_depth=2, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=_chain_every_legal_segment_config(), + ) + for rank in range(4) + ) + assert _covered_token_indices(plans) == 
set(range(spec.real_token_count)) + assert not any( + bucket + for plan in plans + for depth_buckets in plan.tree_chain_buckets_by_depth[1:] + for bucket in depth_buckets + ) + assert _remote_parent_state_transfer_count(plans) > 0 + _assert_remote_parent_state_transfers_cover(spec, plans) + + +def test_gdn_tree_cp_randomized_plans_cover_each_token_once() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + config = _chain_every_legal_segment_config() + for seed in range(8): + pack = pack_shared_prefixes( + _random_tree_sequences(seed), + max_depth=4, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=config, + ) + for rank in range(4) + ) + + assert _covered_token_indices(plans) == set(range(spec.real_token_count)) + assert sum(plan.gdn_token_count for plan in plans) == spec.real_token_count + for plan in plans: + for depth_buckets in ( + *plan.tree_segment_buckets_by_depth, + *plan.tree_chain_buckets_by_depth, + ): + for bucket in depth_buckets: + assert bucket.parent_indices is not None + assert int(bucket.real_token_count) > 0 + + +def test_gdn_tree_cp_randomized_plans_pass_health_checks() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + config = GdnPlannerConfig( + cp_chain_min_tokens_per_rank=1, + cp_chain_min_total_tokens=64, + cp_chain_min_prefix_only_tokens=64, + cp_tree_chain_min_total_tokens=64, + cp_tree_chain_min_prefix_only_tokens=64, + max_padding_ratio=4.0, + ) + for seed in range(16): + pack = pack_shared_prefixes( + _random_tree_sequences(seed + 100, max_depth=5), + 
max_depth=5, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=config, + ) + for rank in range(4) + ) + + _assert_tree_plan_health( + spec, plans, max_padding_ratio=config.max_padding_ratio + ) + + +def _chain_every_legal_segment_config(): + from art.megatron.gdn.gdn_shared_prefix import GdnPlannerConfig + + return GdnPlannerConfig( + cp_chain_min_tokens_per_rank=1, + cp_chain_min_total_tokens=1, + cp_chain_min_prefix_only_tokens=1, + max_padding_ratio=4.0, + ) + + +def _covered_token_indices(plans) -> set[int]: + return { + token + for plan in plans + for start, end, _position in plan.gdn_token_ranges + for token in range(start, end) + } + + +def _local_owner_by_family(plans) -> dict[int, int]: + owner_by_family = {} + for rank, plan in enumerate(plans): + for depth_buckets in plan.tree_segment_buckets_by_depth: + for bucket in depth_buckets: + for family_index in bucket.family_indices.tolist(): + previous = owner_by_family.setdefault(int(family_index), rank) + assert previous == rank + return owner_by_family + + +def _assert_remote_parent_state_transfers_cover(spec, plans) -> None: + owner_by_family = _local_owner_by_family(plans) + for family_index, parent_index in enumerate(spec.tree_parent_indices): + if parent_index < 0 or parent_index not in owner_by_family: + continue + source_rank = owner_by_family[parent_index] + dest_rank = owner_by_family[family_index] + if source_rank == dest_rank: + continue + depth = spec.tree_depths[family_index] + source_exchange = plans[source_rank].tree_state_exchanges_by_depth[depth] + dest_exchange = plans[dest_rank].tree_state_exchanges_by_depth[depth] + assert source_exchange is not None + assert dest_exchange is not None + assert parent_index in source_exchange.source_family_indices + assert parent_index in dest_exchange.dest_family_indices + 
matching = [ + transfer + for transfer in dest_exchange.exchange.transfers + if transfer.source_rank == source_rank and transfer.dest_rank == dest_rank + ] + assert matching + + +def _remote_parent_state_transfer_count(plans) -> int: + return sum( + exchange.exchange.cross_rank_token_count + for plan in plans + for exchange in plan.tree_state_exchanges_by_depth + if exchange is not None + ) // len(plans) + + +def _tree_has_children(spec) -> list[bool]: + has_children = [False] * spec.family_count + for parent_index in spec.tree_parent_indices: + if parent_index >= 0: + has_children[parent_index] = True + return has_children + + +def _assert_tree_plan_health(spec, plans, *, max_padding_ratio: float) -> None: + tree_has_children = _tree_has_children(spec) + token_counts = [0] * int(spec.real_token_count) + for plan in plans: + range_tokens = sum( + end - start for start, end, _position in plan.gdn_token_ranges + ) + assert range_tokens == int(plan.gdn_token_count) + assert len(plan.attention_token_indices) == int(plan.attention_token_count) + + bucket_tokens = 0 + for depth_buckets in plan.tree_segment_buckets_by_depth: + for bucket in depth_buckets: + bucket_tokens += int(bucket.real_token_count) + assert bucket.parent_indices is not None + assert int(bucket.parent_indices.numel()) == int(bucket.segment_count) + assert int(bucket.real_token_count) > 0 + padding_ratio = ( + bucket.length * bucket.segment_count / bucket.real_token_count + ) + assert padding_ratio <= max_padding_ratio + bucket_state_flags = { + tree_has_children[family_index] + for family_index in bucket.family_indices.tolist() + } + assert bucket_state_flags == {bucket.needs_final_state} + for family_index, parent_index in zip( + bucket.family_indices.tolist(), + bucket.parent_indices.tolist(), + strict=True, + ): + assert spec.tree_parent_indices[family_index] == parent_index + + for depth_buckets in plan.tree_chain_buckets_by_depth: + for bucket in depth_buckets: + bucket_tokens += 
int(bucket.real_token_count) + assert bucket.parent_indices is not None + assert int(bucket.parent_indices.numel()) == int(bucket.segment_count) + assert int(bucket.real_token_count) > 0 + padding_ratio = ( + bucket.length * bucket.segment_count / bucket.real_token_count + ) + assert padding_ratio <= max_padding_ratio + bucket_state_flags = { + tree_has_children[family_index] + for family_index in bucket.family_indices.tolist() + } + if bucket.needs_final_state: + assert any(bucket_state_flags) + else: + assert bucket_state_flags == {False} + for family_index, parent_index in zip( + bucket.family_indices.tolist(), + bucket.parent_indices.tolist(), + strict=True, + ): + assert spec.tree_parent_indices[family_index] == parent_index + assert bucket_tokens == int(plan.gdn_token_count) + + for start, end, _position in plan.gdn_token_ranges: + for token_index in range(start, end): + token_counts[token_index] += 1 + + _assert_remote_parent_state_transfers_cover(spec, plans) + assert token_counts == [1] * int(spec.real_token_count) + rank_tokens = [int(plan.gdn_token_count) for plan in plans] + assert max(rank_tokens) - min(rank_tokens) <= max(256, spec.real_token_count // 3) + + +def _random_tree_sequences( + seed: int, *, max_depth: int = 4 +) -> tuple[torch.Tensor, ...]: + generator = torch.Generator().manual_seed(seed) + next_token = 1 + + def tokens(length: int) -> torch.Tensor: + nonlocal next_token + out = torch.arange(next_token, next_token + length) + next_token += length + return out + + def randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=generator).item()) + + def walk(prefix: torch.Tensor, depth: int) -> list[torch.Tensor]: + segment_length = [1, 3, 17, 64, 129, 257][randint(0, 5)] + here = torch.cat((prefix, tokens(segment_length))) + if depth + 1 >= max_depth: + return [ + torch.cat((here, tokens(randint(1, 9)))) for _ in range(randint(2, 4)) + ] + leaves: list[torch.Tensor] = [] + for _ in range(randint(2, 3)): + 
leaves.extend(walk(here, depth + 1)) + return leaves + + return tuple(walk(torch.empty(0, dtype=torch.long), 0)) diff --git a/tests/unit/test_trainer_rank_validation.py b/tests/unit/test_trainer_rank_validation.py new file mode 100644 index 000000000..005ea0757 --- /dev/null +++ b/tests/unit/test_trainer_rank_validation.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch + +from art.megatron.trainer_rank import ( + ForwardInput, + TrainerRank, + Unset, + _validate_top_k, +) + + +class _Model: + vocab_size = 8 + + +def test_forward_input_rejects_non_positive_top_k() -> None: + with pytest.raises(ValueError, match="top_k must be >= 1"): + ForwardInput(input_tokens=torch.tensor([1]), top_k=0) + + +def test_forward_input_adapter_selection_defaults_to_unset() -> None: + request = ForwardInput(input_tokens=torch.tensor([1])) + + assert request.checkpoint is Unset + assert request.lora is Unset + + +def test_forward_input_accepts_explicit_base_checkpoint() -> None: + request = ForwardInput(input_tokens=torch.tensor([1]), checkpoint=None) + + assert request.checkpoint is None + assert request.lora is Unset + + +def test_forward_input_rejects_checkpoint_and_lora_together() -> None: + with pytest.raises(ValueError, match="cannot set both checkpoint and lora"): + ForwardInput(input_tokens=torch.tensor([1]), checkpoint="a", lora="b") + + +def test_validate_top_k_rejects_values_above_vocab_size() -> None: + with pytest.raises(ValueError, match="top_k=9 exceeds vocabulary size 8"): + _validate_top_k(9, _Model()) # type: ignore[arg-type] + + +def test_trainer_rank_accepts_nested_shared_prefix_for_gdn_runtime() -> None: + runtime = SimpleNamespace( + model=[torch.nn.Linear(1, 1)], + optimizer=None, + model_support_handler=SimpleNamespace(build_gdn_execution_spec=True), + ) + + trainer = TrainerRank(runtime, shared_prefix_max_depth=2) # type: ignore[arg-type] + + assert trainer.shared_prefix_max_depth == 2 + + 
def test_trainer_rank_accepts_zero_depth_shared_prefix_for_gdn_runtime() -> None:
    """TrainerRank stores ``shared_prefix_max_depth=0`` for a GDN-style stub runtime.

    The runtime is a bare namespace whose support handler sets
    ``build_gdn_execution_spec=True``; no real Megatron runtime is built.
    """
    gdn_handler = SimpleNamespace(build_gdn_execution_spec=True)
    stub_runtime = SimpleNamespace(
        model=[torch.nn.Linear(1, 1)],
        optimizer=None,
        model_support_handler=gdn_handler,
    )

    rank = TrainerRank(stub_runtime, shared_prefix_max_depth=0)  # type: ignore[arg-type]

    assert rank.shared_prefix_max_depth == 0


def test_trainer_rank_pop_rejects_empty_adapter_stack() -> None:
    """Popping before anything was pushed raises ``RuntimeError``.

    The error message must match ``"No pushed LoRA or checkpoint"``.
    """
    gdn_handler = SimpleNamespace(build_gdn_execution_spec=True)
    stub_runtime = SimpleNamespace(
        model=[torch.nn.Linear(1, 1)],
        optimizer=None,
        model_support_handler=gdn_handler,
    )
    rank = TrainerRank(stub_runtime)  # type: ignore[arg-type]

    # Nothing has been pushed on this fresh instance, so the pop must fail.
    with pytest.raises(RuntimeError, match="No pushed LoRA or checkpoint"):
        rank.pop_pushed_lora_or_checkpoint()