Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/workflows/build-gpu-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ on:
required: true
default: true
type: boolean
prewarm_modal:
description: "Prebuild the pushed image in Modal when auth is configured"
required: true
default: true
type: boolean
prewarm_timeout:
description: "Timeout for GPU node prewarm rollout"
required: true
Expand Down Expand Up @@ -155,11 +160,16 @@ jobs:
PULL_IMAGE_REPO: ${{ inputs.pull_image_repo || 'docker.io/bradhiltonnw/art-gpu' }}
IMAGE_TAG: ${{ inputs.tag }}
NO_CACHE: ${{ inputs.no_cache }}
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
PREWARM_MODAL_INPUT: ${{ inputs.prewarm_modal }}
PREWARM_NODES: ${{ inputs.prewarm_nodes }}
PREWARM_TIMEOUT: ${{ inputs.prewarm_timeout }}
run: |
IMAGE_TAG="${IMAGE_TAG:-latest}"
NO_CACHE="${NO_CACHE:-false}"
export PREWARM_MODAL="${PREWARM_MODAL:-auto}"
PREWARM_MODAL_INPUT="${PREWARM_MODAL_INPUT:-true}"
PREWARM_NODES="${PREWARM_NODES:-true}"
PREWARM_TIMEOUT="${PREWARM_TIMEOUT:-30m}"

Expand All @@ -175,6 +185,10 @@ jobs:
args+=(--no-cache)
fi

if [ "${PREWARM_MODAL_INPUT}" = "false" ]; then
args+=(--no-prewarm-modal)
fi

if [ "${PREWARM_NODES}" != "true" ]; then
args+=(--no-prewarm-nodes)
fi
Expand Down
116 changes: 116 additions & 0 deletions dev/trainer_rank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from __future__ import annotations

import os

import torch
import torch.distributed as dist
from transformers import AutoTokenizer
import typer

from art.megatron.trainer_rank import AdamParams, ForwardInput, TrainerRank


def main(
model: str = "Qwen/Qwen3-0.6B",
dataset: str = "roneneldan/TinyStories",
split: str = "train",
text_column: str = "text",
samples: int = 16,
steps: int = 1,
micro_batch_size: int = 1,
lr: float = 5e-5,
layers: int = 2,
max_seq_length: int = 256,
) -> None:
os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1")
os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1")
os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1")

if not torch.cuda.is_available():
raise RuntimeError("dev/trainer_rank.py requires CUDA")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="nccl")

try:
from datasets import load_dataset

from art.megatron import train as megatron_train

tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
inputs: list[ForwardInput[torch.Tensor, None, None, None]] = []
for row in load_dataset(dataset, split=split, streaming=True):
text = str(row.get(text_column, "")).strip() # type: ignore[union-attr]
if not text:
continue
token_ids = tokenizer(
text,
add_special_tokens=True,
truncation=True,
max_length=max_seq_length + 1,
return_tensors="pt",
)["input_ids"].reshape(-1)
if int(token_ids.numel()) <= 1:
continue
inputs.append(
ForwardInput(
input_tokens=token_ids[:-1],
target_tokens=token_ids[1:],
)
)
if len(inputs) >= samples:
break
if not inputs:
raise RuntimeError("dataset produced no tokenized training examples")

runtime = megatron_train.build_training_runtime(
model_identifier=model,
provider_configure=lambda provider: setattr(
provider,
"num_layers",
layers,
),
print_env=dist.get_rank() == 0,
)
rank = TrainerRank(runtime, micro_batch_size=micro_batch_size)
if dist.get_rank() == 0:
print(
"TrainerRank ready: "
f"dp={megatron_train.ps.get_data_parallel_world_size()} "
f"device={rank.device}",
flush=True,
)

for step in range(steps):
loss_sum = torch.tensor(0.0, device=rank.device)
token_count = torch.tensor(0.0, device=rank.device)
for micro in rank.micro_batches(inputs):
outputs = rank.forward(micro.inputs)
loss = torch.tensor(0.0, device=rank.device)
for output in outputs:
assert output.target_logprobs is not None
loss = loss - output.target_logprobs.sum()
token_count += output.target_logprobs.numel()
if loss.requires_grad:
loss.backward()
loss_sum += loss.detach()

rank.dp_reduce(loss_sum)
rank.dp_reduce(token_count)
scale = 1.0 / max(float(token_count.item()), 1.0)
metrics = rank.optim_step(
params=AdamParams(learning_rate=lr),
scale_grads=scale,
)
metrics["loss"] = float(loss_sum.item() * scale)
metrics["tokens"] = float(token_count.item())
if dist.get_rank() == 0:
print(f"step={step} {metrics}", flush=True)

dist.barrier()
finally:
if dist.is_initialized():
dist.destroy_process_group()


if __name__ == "__main__":
typer.run(main)
Loading
Loading