
Refactor async HF metadata prewrite and background finalize #1858

Open
caixianzhang wants to merge 19 commits into InternLM:main from caixianzhang:refactor/async-hf-background-finalize


@caixianzhang caixianzhang commented May 31, 2026

Contributor

Summary

This PR refactors async HF export so metadata is prepared before the safetensors writer starts, and the final checkpoint directory is published by a background finalize thread as soon as the writer completes.

It supersedes #1857 and carries over its fixes:

  • async HF non-weight files now prefer config.save_hf() / hf_config, matching sync HF behavior, and only fall back to copying from _hf_path
  • compose async HF writes model.safetensors.index.json with the same JSON formatting as compose sync HF

Main changes:

  • add an async HF runtime initialized before training starts when async_hf_export is enabled
  • create a dedicated Gloo process group for async HF control-plane collectives
  • prewrite config/tokenizer/processor files and model.safetensors.index.json into hidden .hf-*.incomplete directories before launching the writer process
  • keep the writer process focused on safetensors shard I/O and per-rank status writing
  • add a single-worker background finalize executor that joins the writer, gathers per-rank status, and renames .hf-*.incomplete to hf-* once all ranks succeed
  • record background finalize failures via future callback and let Trainer fail fast at the iteration control point instead of waiting for the next save
  • keep async HF runtime initialization fail-fast: initialization errors abort training before the first step
  • use async HF-specific file-write lock slots via XTUNER_ASYNC_HF_FILE_WRITE_LOCK_SLOTS, with XTUNER_HF_SAVE_WRITER_SAVE_FILE_LOCK_SLOTS retained only as an HF-specific compatibility alias; async HF no longer reads ASYNC_DCP_FILE_WRITE_LOCK_SLOTS
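
The staging and publish steps listed above can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's implementation: `prewrite_metadata`, `finalize`, and `run_demo` are hypothetical names, the per-rank status is passed in as a plain list instead of being gathered over a process group, and no writer process is involved.

```python
import json
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def prewrite_metadata(root: Path, step: int, index: dict) -> Path:
    """Write non-weight files into a hidden staging directory before any shard I/O."""
    staging = root / f".hf-{step}.incomplete"  # naming follows the PR text; layout is assumed
    staging.mkdir(parents=True, exist_ok=True)
    index_path = staging / "model.safetensors.index.json"
    index_path.write_text(json.dumps(index, indent=2))
    return staging


def finalize(staging: Path, rank_ok: list) -> Path:
    """Publish the directory with a rename, but only if every rank succeeded."""
    if not all(rank_ok):
        raise RuntimeError(f"{staging.name}: at least one rank failed")
    # ".hf-10.incomplete" -> "hf-10"
    final = staging.with_name(staging.name[1:].removesuffix(".incomplete"))
    staging.rename(final)
    return final


def run_demo() -> dict:
    # A single worker keeps finalize steps in submission order.
    executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="hf-finalize")
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        good = prewrite_metadata(root, 10, {"weight_map": {}})
        bad = prewrite_metadata(root, 20, {"weight_map": {}})
        good_future = executor.submit(finalize, good, [True, True])
        bad_future = executor.submit(finalize, bad, [True, False])
        executor.shutdown(wait=True)
        return {
            "published": good_future.result().name,
            "index_kept": (root / "hf-10" / "model.safetensors.index.json").exists(),
            "failed": isinstance(bad_future.exception(), RuntimeError),
            "bad_still_hidden": bad.exists() and not (root / "hf-20").exists(),
        }


result = run_demo()
print(result)
```

Because the rename is the only step that makes a checkpoint visible under its final name, a reader of the output directory should never see a half-written `hf-*` directory in this sketch; a failed save simply stays hidden as `.hf-*.incomplete`.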

Validation

Passed:

  • python3 -m py_compile xtuner/v1/utils/process.py xtuner/v1/model/base.py xtuner/v1/model/compose/base.py xtuner/v1/engine/train_engine.py xtuner/v1/train/trainer.py tests/train/test_trainer_async_hf.py
  • git diff --check

Not run locally:

  • pytest is not installed in this environment
  • python3 -m unittest tests.train.test_trainer_async_hf -v cannot import PyTorch internal distributed test helpers because expecttest is missing

Cluster validation still recommended:

  • bash sft_files/run_qwen3_moe_30ba3_ep8_hf_mem_suite.sh
  • bash sft_files/run_qwen3_vl_moe_30ba3_ep8_hf_mem_suite.sh
  • bash sft_files/compare_hf_non_safetensors.sh
  • bash sft_files/check_hf_from_pretrained.sh

@caixianzhang caixianzhang force-pushed the refactor/async-hf-background-finalize branch 2 times, most recently from 18d9e43 to cded432 Compare May 31, 2026 03:33
@caixianzhang caixianzhang force-pushed the refactor/async-hf-background-finalize branch from cded432 to e00253f Compare May 31, 2026 03:43

Collaborator

Delete this fake test and run a real asynchronous test directly at:

return TrainerConfig(

Beyond end-to-end correctness, we can use hooks to verify that, in the asynchronous scenario, the HF and DCP saves for step N complete promptly instead of lagging behind.

        self._terminate_failure(failure)

    def _terminate_failure(self, failure: tuple[AsyncSaveWatchItem, BaseException]) -> None:
        with self._lock:

Collaborator

Is this lock necessary?

Comment on lines +177 to +191
if isinstance(hf_dir, str):
    hf_dir = Path(hf_dir)

if self.config.hf_config is not None:
    self.config.save_hf(hf_dir)
elif self._hf_path is not None:
    for file in cast(Path, self._hf_path).iterdir():
        if file.suffix != ".safetensors":
            target_path = hf_dir / file.name
            if file.is_file():
                copy(file, target_path)
            else:
                copytree(file, target_path, ignore_dangling_symlinks=True, dirs_exist_ok=True)
else:
    raise RuntimeError("Internal Error, both self.config.hf_config and self._hf_path are None")

Collaborator

Extract this code into the model's base.py so it can be shared.

Comment on lines +196 to +199
    def destroy_async_hf_resources(self) -> None:
        super().destroy_async_hf_resources()
        for module in (self.language_model, self.vision_tower, self.multi_modal_projector):
            module._async_hf_tensor_cache.clear()

Collaborator

Likewise, base can use isinstance to iterate over all BaseModel instances and handle their _async_hf_tensor_cache there.
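
One way to read this suggestion is sketched below. It is a plain-Python stand-in rather than xtuner code: `BaseModel` here is a hypothetical minimal class, `iter_base_models` is an invented helper, and the projector is assumed not to be a `BaseModel`. With real `torch.nn.Module` subclasses, the traversal would presumably go through `self.modules()` with the same `isinstance` filter.

```python
class BaseModel:
    """Stand-in for the real base class; only the cache attribute matters here."""

    def __init__(self) -> None:
        self._async_hf_tensor_cache: dict = {}

    def iter_base_models(self):
        """Yield this model and every nested BaseModel, skipping other attributes."""
        yield self
        for value in vars(self).values():
            if isinstance(value, BaseModel):
                yield from value.iter_base_models()

    def destroy_async_hf_resources(self) -> None:
        # The base class clears every cache, so subclasses need no override.
        for model in self.iter_base_models():
            model._async_hf_tensor_cache.clear()


class Projector:
    """A submodule that is not a BaseModel and therefore has no cache."""


class ComposeModel(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.language_model = BaseModel()
        self.vision_tower = BaseModel()
        self.multi_modal_projector = Projector()


model = ComposeModel()
model.language_model._async_hf_tensor_cache["w"] = object()
model.vision_tower._async_hf_tensor_cache["w"] = object()
model.destroy_async_hf_resources()
remaining = [len(m._async_hf_tensor_cache) for m in model.iter_base_models()]
print(remaining)
```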

Comment thread xtuner/v1/model/base.py
        with profile_time_and_memory(f"[Saving HF to [{safetensors_prefix}]{hf_dir} cost]"):
            self._save_hf(hf_dir=hf_dir, save_dtype=save_dtype, safetensors_prefix=safetensors_prefix)

    def init_async_hf_resources(self) -> None:

Collaborator

This should not be a public interface.

Comment thread xtuner/v1/model/base.py
Comment on lines +899 to +901
self._barrier_async_hf()
log_rank0.info(f"[Async saving HF to {hf_dir}] finalized")
self._barrier_async_hf()

Collaborator

Too many barriers.

Comment on lines +609 to +612
self._pending_async_hf_future: Future[Path] | None = None
self._pending_async_hf_finalize_done: threading.Event | None = None
self._async_save_monitor = AsyncSaveMonitor()
self._save_finalize_lock = threading.RLock()

Collaborator

Add an internal function that initializes the async-related resources.

Comment on lines +893 to +900
if self._async_hf_export:
    self._wait_for_pending_async_hf()
    self._engine.model.destroy_async_hf_resources()

# TODO: Should use flush rather than close
self._wait_for_pending_checkpoint()
if self._async_hf_export or self._async_checkpoint:
    self._async_save_monitor.stop()

Collaborator

The style is inconsistent: the HF part destroys its resources inside the if, while the DCP part does so unconditionally. Both are acceptable, but I prefer destroying inside the if.

log_rank0.info(f"Health check passed at step {self.cur_step}")

-def _wait_for_pending_checkpoint(self, timeout: int = 3000) -> None:
+def _wait_for_pending_checkpoint(self) -> None:

Collaborator

Where is the current timeout handling?
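
For reference, a bounded wait on a pending save could look like the sketch below. It only illustrates the general pattern with `concurrent.futures`; `wait_for_pending` is a hypothetical helper, and nothing here claims to show how the PR handles timeouts after the parameter was removed.

```python
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError
from typing import Optional


def wait_for_pending(future: Optional[Future], timeout: float) -> None:
    """Block until the pending save finishes, or fail loudly after `timeout` seconds."""
    if future is None:
        return
    try:
        future.result(timeout=timeout)
    except FutureTimeoutError as exc:
        raise RuntimeError(f"pending save still running after {timeout}s") from exc


release = threading.Event()
executor = ThreadPoolExecutor(max_workers=1)
pending = executor.submit(release.wait)  # stands in for a slow checkpoint write

try:
    wait_for_pending(pending, timeout=0.05)
    timed_out = False
except RuntimeError:
    timed_out = True

release.set()  # let the fake save finish
wait_for_pending(pending, timeout=5.0)
executor.shutdown(wait=True)
print(timed_out, pending.done())
```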

Comment on lines +1297 to +1312
current_exp = self.meta.latest_exp
ckp_list = current_exp.checkpoint_list if not is_snapshot else current_exp.snap_checkpoint_list
ckp_list.append(str(checkpoint_path))
current_exp.cur_step = self.cur_step
current_exp.cur_epoch = self._cur_epoch
current_exp.consumed_tokens = int(total_consumed_tokens)
current_exp.history[-1]["end"] = self.cur_step

# Delete checkpoints and update meta's checkpoint_list
ckp_maxkeep = self._checkpoint_maxkeep if not is_snapshot else 1
if ckp_maxkeep is not None and ckp_maxkeep > 0 and len(ckp_list) > ckp_maxkeep:
    ckp_pop_num = len(ckp_list) - ckp_maxkeep
    for _ in range(ckp_pop_num):
        deleted_ckp = ckp_list.pop(0)
        if self.rank == 0 and Path(deleted_ckp).exists():
            rmtree(deleted_ckp)

Collaborator

This needs to run in a callback, similar to finalize_hf. Otherwise the meta information may be saved before DCP has finished writing the weights; if the process terminates midway, resume will then read incorrect meta information.
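
The ordering concern can be illustrated with a small sketch in which the checkpoint list is updated only from a completion callback. All names (`save_weights`, `commit_on_success`, the paths) are hypothetical, and the sketch ignores multi-rank coordination and checkpoint pruning.

```python
from concurrent.futures import Future, ThreadPoolExecutor


def save_weights(path: str, fail: bool) -> str:
    """Stand-in for the asynchronous weight save."""
    if fail:
        raise OSError(f"could not write {path}")
    return path


def commit_on_success(checkpoint_list: list, path: str):
    """Build a callback that records `path` only once its save has succeeded."""

    def _callback(future: Future) -> None:
        if not future.cancelled() and future.exception() is None:
            checkpoint_list.append(path)

    return _callback


checkpoint_list: list = []
executor = ThreadPoolExecutor(max_workers=1)
for path, fail in [("ckpt/step-10", False), ("ckpt/step-20", True)]:
    future = executor.submit(save_weights, path, fail)
    future.add_done_callback(commit_on_success(checkpoint_list, path))
executor.shutdown(wait=True)  # callbacks have run once the worker has exited
print(checkpoint_list)
```

With this shape, a crash between submitting the save and its completion leaves the list without the new entry, so a resume would fall back to the last checkpoint that was fully written.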

@HAOCHENYE

Collaborator

@claude review

Comment on lines +1 to +48
import os
import signal
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path

from xtuner.v1.utils import get_logger


logger = get_logger()


@dataclass
class AsyncSaveWatchItem:
    name: str
    future: Future
    path: Path
    step: int
    epoch: int | None


class AsyncSaveMonitor:
    def __init__(self, interval: float = 5.0):
        self._items: list[AsyncSaveWatchItem] = []
        self._terminated = False
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._interval = interval
        self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="AsyncSaveMonitor")
        self._monitor_future: Future

    def start(self) -> None:
        self._monitor_future = self._executor.submit(self._run)

    def register(self, item: AsyncSaveWatchItem) -> None:
        with self._lock:
            self._items.append(item)

    def stop(self) -> None:
        self._stop_event.set()
        self._monitor_future.result()
        self._executor.shutdown(wait=True)

    def _run(self) -> None:
        while True:
            stopped = self._stop_event.wait(self._interval)
            if stopped:

Contributor

Claude: The AsyncSaveMonitor class has multiple intermediate revisions visible in the PR (thread-based, then executor-based, then os.killpg-based). The final state in the squashed diff adds tests/utils/test_async_save_monitor.py which tests _check_watched_futures and os.killpg, but the final version of async_save_monitor.py itself uses _poll_once and never defines _check_watched_futures.

This appears to be a residual mismatch from the commit sequence where PATCH 10 renamed _poll_once to _check_watched_futures but PATCH 13 reverted the test file to the original test_trainer.py. Please verify the final test file (tests/utils/test_async_save_monitor.py) calls the correct method name that exists in the final AsyncSaveMonitor class.

def get_async_save_file_lock_slots() -> int:
    return max(0, int(os.environ.get("ASYNC_DCP_FILE_WRITE_LOCK_SLOTS", "1")))
def get_async_hf_save_file_lock_slots() -> int:
    return max(0, int(os.environ.get("ASYNC_HF_FILE_WRITE_LOCK_SLOTS", "1")))

Contributor

Claude: The env var name changed three times across this PR:

  1. ASYNC_DCP_FILE_WRITE_LOCK_SLOTS (original)
  2. XTUNER_ASYNC_HF_FILE_WRITE_LOCK_SLOTS with XTUNER_HF_SAVE_WRITER_SAVE_FILE_LOCK_SLOTS fallback (PATCH 01)
  3. ASYNC_HF_FILE_WRITE_LOCK_SLOTS (PATCH 02, final)

The PR description says:

use async HF-specific file-write lock slots via XTUNER_ASYNC_HF_FILE_WRITE_LOCK_SLOTS, with XTUNER_HF_SAVE_WRITER_SAVE_FILE_LOCK_SLOTS retained only as an HF-specific compatibility alias

But the final code uses ASYNC_HF_FILE_WRITE_LOCK_SLOTS without the XTUNER_ prefix and without any compatibility alias. This contradicts the PR description. Which is intentional — the final env var name or the description?
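
For comparison, the behavior the PR description promises (a primary variable plus a compatibility alias) might look like the sketch below. It is not the code in this PR, which per the comment above reads ASYNC_HF_FILE_WRITE_LOCK_SLOTS only. The function name and the injectable `env` parameter are assumptions made for illustration; the default of "1" mirrors the snippet under review.

```python
import os
from typing import Mapping

PRIMARY = "XTUNER_ASYNC_HF_FILE_WRITE_LOCK_SLOTS"
ALIAS = "XTUNER_HF_SAVE_WRITER_SAVE_FILE_LOCK_SLOTS"  # compatibility alias only


def get_async_hf_file_write_lock_slots(env: Mapping[str, str] = os.environ) -> int:
    """Prefer the primary variable, fall back to the alias, then to a default of 1."""
    raw = env.get(PRIMARY)
    if raw is None:
        raw = env.get(ALIAS, "1")
    return max(0, int(raw))  # negative values are clamped to 0


slots_default = get_async_hf_file_write_lock_slots({})
slots_alias = get_async_hf_file_write_lock_slots({ALIAS: "4"})
slots_primary = get_async_hf_file_write_lock_slots({PRIMARY: "2", ALIAS: "4"})
slots_clamped = get_async_hf_file_write_lock_slots({PRIMARY: "-3"})
print(slots_default, slots_alias, slots_primary, slots_clamped)
```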

monitor.register(item)
with patch("xtuner.v1.utils.async_save_monitor.os.killpg") as killpg:
    monitor._check_watched_futures()

Contributor

Claude: Critical: Test calls _check_watched_futures() which doesn't exist in the final AsyncSaveMonitor.

In the final state (after PATCH 14/15 reverts back closer to the thread-based _poll_once implementation), the AsyncSaveMonitor class uses _poll_once(), not _check_watched_futures(). But this test file (added in PATCH 10) references _check_watched_futures, which was the name used only in the intermediate executor-based version.

This test will fail with AttributeError at runtime. The method calls should be monitor._poll_once() instead.


monitor.register(item)
with (
    patch("xtuner.v1.utils.async_save_monitor.os.getpgrp", return_value=1234) as getpgrp,

Contributor

Claude: Same issue as above — _check_watched_futures() does not exist in the final AsyncSaveMonitor class. Should be _poll_once().

monitor._check_watched_futures()

getpgrp.assert_called_once_with()
killpg.assert_called_once_with(1234, signal.SIGTERM)

Contributor

Claude: Correction to my earlier comments: I incorrectly stated that _check_watched_futures doesn't exist. PATCH 10 renames _poll_once to _check_watched_futures, so this test is correct. Please disregard my previous comments about this method name mismatch. Apologies for the noise.


 def _get_file_write_lock_slots() -> int:
-    return max(0, int(os.environ.get("ASYNC_DCP_FILE_WRITE_LOCK_SLOTS", "0")))
+    return max(0, int(os.environ.get("ASYNC_DCP_FILE_WRITE_LOCK_SLOTS", "1")))

Contributor

Claude: Behavioral change not mentioned in PR description. The default for ASYNC_DCP_FILE_WRITE_LOCK_SLOTS is changed from "0" (no file locking for DCP writes) to "1" (single concurrent writer slot). This affects DCP save I/O concurrency independently of the async HF refactoring. If intentional, worth noting in the PR description since it changes default behavior for existing users.

@claude

claude Bot commented Jun 9, 2026

Contributor

Claude: Review Summary

This PR refactors async HF export to prewrite metadata before the safetensors writer starts and finalize via a background commit thread. The architecture is sound — separating the writer process (safetensors I/O) from the commit/finalize phase (all_gather status, rename) is a good improvement.

Issues

Warning

  • Env var name mismatch with PR description (xtuner/v1/utils/process.py:30): PR description says XTUNER_ASYNC_HF_FILE_WRITE_LOCK_SLOTS with a compat alias, but final code uses ASYNC_HF_FILE_WRITE_LOCK_SLOTS with no alias. Clarify which is intended.
  • DCP lock slot default silently changed (xtuner/v1/patch/xtuner_storage.py:73): ASYNC_DCP_FILE_WRITE_LOCK_SLOTS default changed from "0" to "1". This is a behavioral change to DCP save unrelated to async HF and is worth calling out explicitly.
  • os.killpg(SIGTERM) as the sole failure signal: When AsyncSaveMonitor detects a failed future, it terminates the process group. This means async save failures between HF save steps cannot be caught and reported gracefully by the training loop. The earlier raise_if_failed() call in the loop (PATCH 08) was removed. If SIGTERM is the intentional design, consider documenting this as a conscious tradeoff.
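
The graceful alternative mentioned in the last warning could be sketched as below: failures are recorded from a future callback and re-raised at a point the training loop controls. `AsyncSaveFailureRecorder` and its methods are hypothetical names, not the removed `raise_if_failed()` from PATCH 08, whose actual shape is not shown in this conversation.

```python
import threading
from concurrent.futures import Future
from typing import Optional, Tuple


class AsyncSaveFailureRecorder:
    """Remember the first failed save so the training loop can raise it later."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._failure: Optional[Tuple[str, BaseException]] = None

    def watch(self, name: str, future: Future) -> None:
        def _on_done(done: Future) -> None:
            if done.cancelled():
                return
            exc = done.exception()
            if exc is not None:
                with self._lock:
                    if self._failure is None:  # keep only the first failure
                        self._failure = (name, exc)

        future.add_done_callback(_on_done)

    def raise_if_failed(self) -> None:
        with self._lock:
            failure = self._failure
        if failure is not None:
            name, exc = failure
            raise RuntimeError(f"async save {name!r} failed") from exc


recorder = AsyncSaveFailureRecorder()
save_future: Future = Future()
recorder.watch("hf-step-3", save_future)
recorder.raise_if_failed()  # nothing has failed yet, so this is a no-op

save_future.set_exception(OSError("disk full"))  # simulate a writer failure
try:
    recorder.raise_if_failed()
    message = ""
except RuntimeError as err:
    message = f"{err} (cause: {err.__cause__})"
print(message)
```

The tradeoff is that this approach only surfaces the error when the loop next calls `raise_if_failed()`, whereas a process-group signal stops every rank immediately, including ranks blocked in a collective.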

Nit

  • Correction: My initial inline comments about _check_watched_futures vs _poll_once were incorrect — please disregard. PATCH 10 correctly renames the method and the tests call the right name.
  • The PR carries 15 commits with several back-and-forth refactors (warmup added then removed, file_finalize_callback added then removed, executor-based monitor then reverted). Consider squashing before merge for a cleaner history.

Verdict

COMMENT — no blocking correctness bugs found, but the env var naming discrepancy and silent DCP default change should be confirmed as intentional before merge.


2 participants