Refactor async HF metadata prewrite and background finalize #1858
Refactor async HF metadata prewrite and background finalize #1858 caixianzhang wants to merge 19 commits into
Conversation
18d9e43 to
cded432
Compare
cded432 to
e00253f
Compare
…ackground-finalize
…ackground-finalize
…ackground-finalize
…ackground-finalize
There was a problem hiding this comment.
Delete this fake test and directly run a real asynchronous test at
xtuner/tests/train/test_trainer.py
Line 485 in d92529d
In addition to end-to-end correctness, we can use hooks to verify that in the asynchronous scenario, the hf and dcp saves at the N-th step are completed in a timely manner without lagging behind.
| self._terminate_failure(failure) | ||
|
|
||
| def _terminate_failure(self, failure: tuple[AsyncSaveWatchItem, BaseException]) -> None: | ||
| with self._lock: |
| if isinstance(hf_dir, str): | ||
| hf_dir = Path(hf_dir) | ||
|
|
||
| if self.config.hf_config is not None: | ||
| self.config.save_hf(hf_dir) | ||
| elif self._hf_path is not None: | ||
| for file in cast(Path, self._hf_path).iterdir(): | ||
| if file.suffix != ".safetensors": | ||
| target_path = hf_dir / file.name | ||
| if file.is_file(): | ||
| copy(file, target_path) | ||
| else: | ||
| copytree(file, target_path, ignore_dangling_symlinks=True, dirs_exist_ok=True) | ||
| else: | ||
| raise RuntimeError("Internal Error, both self.config.hf_config and self._hf_path are None") |
There was a problem hiding this comment.
Extract this part of the code into model base.py for shared use.
| def destroy_async_hf_resources(self) -> None: | ||
| super().destroy_async_hf_resources() | ||
| for module in (self.language_model, self.vision_tower, self.multi_modal_projector): | ||
| module._async_hf_tensor_cache.clear() |
There was a problem hiding this comment.
同理,在 base 里可以用 isinstance 去遍历所有的 BaseModel,去执行 _async_hf_tensor_cache
| with profile_time_and_memory(f"[Saving HF to [{safetensors_prefix}]{hf_dir} cost]"): | ||
| self._save_hf(hf_dir=hf_dir, save_dtype=save_dtype, safetensors_prefix=safetensors_prefix) | ||
|
|
||
| def init_async_hf_resources(self) -> None: |
There was a problem hiding this comment.
Should not be a public interface
| self._barrier_async_hf() | ||
| log_rank0.info(f"[Async saving HF to {hf_dir}] finalized") | ||
| self._barrier_async_hf() |
| self._pending_async_hf_future: Future[Path] | None = None | ||
| self._pending_async_hf_finalize_done: threading.Event | None = None | ||
| self._async_save_monitor = AsyncSaveMonitor() | ||
| self._save_finalize_lock = threading.RLock() |
There was a problem hiding this comment.
Add an internal function to initialize asynchronous related resources
| if self._async_hf_export: | ||
| self._wait_for_pending_async_hf() | ||
| self._engine.model.destroy_async_hf_resources() | ||
|
|
||
| # TODO: Should use flush rather than close | ||
| self._wait_for_pending_checkpoint() | ||
| if self._async_hf_export or self._async_checkpoint: | ||
| self._async_save_monitor.stop() |
There was a problem hiding this comment.
The style is different. The hf part destroys within the if, while the dcp part destroys uniformly. Although both are acceptable, I prefer to destroy within the if
| log_rank0.info(f"Health check passed at step {self.cur_step}") | ||
|
|
||
| def _wait_for_pending_checkpoint(self, timeout: int = 3000) -> None: | ||
| def _wait_for_pending_checkpoint(self) -> None: |
There was a problem hiding this comment.
Where is the current timeout handling?
| current_exp = self.meta.latest_exp | ||
| ckp_list = current_exp.checkpoint_list if not is_snapshot else current_exp.snap_checkpoint_list | ||
| ckp_list.append(str(checkpoint_path)) | ||
| current_exp.cur_step = self.cur_step | ||
| current_exp.cur_epoch = self._cur_epoch | ||
| current_exp.consumed_tokens = int(total_consumed_tokens) | ||
| current_exp.history[-1]["end"] = self.cur_step | ||
|
|
||
| # Delete checkpoints and update meta's checkpoint_list | ||
| ckp_maxkeep = self._checkpoint_maxkeep if not is_snapshot else 1 | ||
| if ckp_maxkeep is not None and ckp_maxkeep > 0 and len(ckp_list) > ckp_maxkeep: | ||
| ckp_pop_num = len(ckp_list) - ckp_maxkeep | ||
| for _ in range(ckp_pop_num): | ||
| deleted_ckp = ckp_list.pop(0) | ||
| if self.rank == 0 and Path(deleted_ckp).exists(): | ||
| rmtree(deleted_ckp) |
There was a problem hiding this comment.
It needs to be executed via a callback, similar to finalize_hf. Otherwise, DCP may finish saving the meta information but the weights themselves are not fully saved, and if the process terminates midway, it will eventually lead to reading incorrect meta information upon resume.
|
@claude review |
| import os | ||
| import signal | ||
| import threading | ||
| from concurrent.futures import Future, ThreadPoolExecutor | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
|
|
||
| from xtuner.v1.utils import get_logger | ||
|
|
||
|
|
||
| logger = get_logger() | ||
|
|
||
|
|
||
| @dataclass | ||
| class AsyncSaveWatchItem: | ||
| name: str | ||
| future: Future | ||
| path: Path | ||
| step: int | ||
| epoch: int | None | ||
|
|
||
|
|
||
| class AsyncSaveMonitor: | ||
| def __init__(self, interval: float = 5.0): | ||
| self._items: list[AsyncSaveWatchItem] = [] | ||
| self._terminated = False | ||
| self._lock = threading.Lock() | ||
| self._stop_event = threading.Event() | ||
| self._interval = interval | ||
| self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="AsyncSaveMonitor") | ||
| self._monitor_future: Future | ||
|
|
||
| def start(self) -> None: | ||
| self._monitor_future = self._executor.submit(self._run) | ||
|
|
||
| def register(self, item: AsyncSaveWatchItem) -> None: | ||
| with self._lock: | ||
| self._items.append(item) | ||
|
|
||
| def stop(self) -> None: | ||
| self._stop_event.set() | ||
| self._monitor_future.result() | ||
| self._executor.shutdown(wait=True) | ||
|
|
||
| def _run(self) -> None: | ||
| while True: | ||
| stopped = self._stop_event.wait(self._interval) | ||
| if stopped: |
There was a problem hiding this comment.
Claude: The AsyncSaveMonitor class has multiple intermediate revisions visible in the PR (thread-based, then executor-based, then os.killpg-based). The final state in the squashed diff adds tests/utils/test_async_save_monitor.py which tests _check_watched_futures and os.killpg, but the final version of async_save_monitor.py itself uses _poll_once and never defines _check_watched_futures.
This appears to be a residual mismatch from the commit sequence where PATCH 10 renamed _poll_once to _check_watched_futures but PATCH 13 reverted the test file to the original test_trainer.py. Please verify the final test file (tests/utils/test_async_save_monitor.py) calls the correct method name that exists in the final AsyncSaveMonitor class.
| def get_async_save_file_lock_slots() -> int: | ||
| return max(0, int(os.environ.get("ASYNC_DCP_FILE_WRITE_LOCK_SLOTS", "1"))) | ||
| def get_async_hf_save_file_lock_slots() -> int: | ||
| return max(0, int(os.environ.get("ASYNC_HF_FILE_WRITE_LOCK_SLOTS", "1"))) |
There was a problem hiding this comment.
Claude: The env var name changed three times across this PR:
- `ASYNC_DCP_FILE_WRITE_LOCK_SLOTS` (original)
- `XTUNER_ASYNC_HF_FILE_WRITE_LOCK_SLOTS` with `XTUNER_HF_SAVE_WRITER_SAVE_FILE_LOCK_SLOTS` fallback (PATCH 01)
- `ASYNC_HF_FILE_WRITE_LOCK_SLOTS` (PATCH 02, final)
The PR description says:
> use async HF-specific file-write lock slots via `XTUNER_ASYNC_HF_FILE_WRITE_LOCK_SLOTS`, with `XTUNER_HF_SAVE_WRITER_SAVE_FILE_LOCK_SLOTS` retained only as an HF-specific compatibility alias
But the final code uses ASYNC_HF_FILE_WRITE_LOCK_SLOTS without the XTUNER_ prefix and without any compatibility alias. This contradicts the PR description. Which is intentional — the final env var name or the description?
| monitor.register(item) | ||
| with patch("xtuner.v1.utils.async_save_monitor.os.killpg") as killpg: | ||
| monitor._check_watched_futures() | ||
|
|
There was a problem hiding this comment.
Claude: Critical: Test calls _check_watched_futures() which doesn't exist in the final AsyncSaveMonitor.
In the final state (after PATCH 14/15 reverts back closer to the thread-based _poll_once implementation), the AsyncSaveMonitor class uses _poll_once(), not _check_watched_futures(). But this test file (added in PATCH 10) references _check_watched_futures, which was the name used only in the intermediate executor-based version.
This test will fail with AttributeError at runtime. The method calls should be monitor._poll_once() instead.
|
|
||
| monitor.register(item) | ||
| with ( | ||
| patch("xtuner.v1.utils.async_save_monitor.os.getpgrp", return_value=1234) as getpgrp, |
There was a problem hiding this comment.
Claude: Same issue as above — _check_watched_futures() does not exist in the final AsyncSaveMonitor class. Should be _poll_once().
| monitor._check_watched_futures() | ||
|
|
||
| getpgrp.assert_called_once_with() | ||
| killpg.assert_called_once_with(1234, signal.SIGTERM) |
There was a problem hiding this comment.
Claude: Correction to my earlier comments: I incorrectly stated that _check_watched_futures doesn't exist. PATCH 10 renames _poll_once → _check_watched_futures, so this test is correct. Please disregard my previous comments about this method name mismatch — apologies for the noise.
|
|
||
| def _get_file_write_lock_slots() -> int: | ||
| return max(0, int(os.environ.get("ASYNC_DCP_FILE_WRITE_LOCK_SLOTS", "0"))) | ||
| return max(0, int(os.environ.get("ASYNC_DCP_FILE_WRITE_LOCK_SLOTS", "1"))) |
There was a problem hiding this comment.
Claude: Behavioral change not mentioned in PR description. The default for ASYNC_DCP_FILE_WRITE_LOCK_SLOTS is changed from "0" (no file locking for DCP writes) to "1" (single concurrent writer slot). This affects DCP save I/O concurrency independently of the async HF refactoring. If intentional, worth noting in the PR description since it changes default behavior for existing users.
|
Claude: ## Review Summary
This PR refactors async HF export to prewrite metadata before the safetensors writer starts and finalize via a background commit thread. The architecture is sound — separating the writer process (safetensors I/O) from the commit/finalize phase (all_gather status, rename) is a good improvement.
Issues
Warning
Nit
Verdict
COMMENT — no blocking correctness bugs found, but the env var naming discrepancy and silent DCP default change should be confirmed as intentional before merge. |
Summary
This PR refactors async HF export so metadata is prepared before the safetensors writer starts, and the final checkpoint directory is published by a background finalize thread as soon as the writer completes.
It supersedes #1857 and carries over its fixes:
- `config.save_hf()` / `hf_config`, matching sync HF behavior, and only fall back to copying from `_hf_path`
- `model.safetensors.index.json` with the same JSON formatting as compose sync HF

Main changes:
- `async_hf_export` is enabled
- `model.safetensors.index.json` into hidden `.hf-*.incomplete` directories before launching the writer process
- `.hf-*.incomplete` to `hf-*` once all ranks succeed
- `XTUNER_ASYNC_HF_FILE_WRITE_LOCK_SLOTS`, with `XTUNER_HF_SAVE_WRITER_SAVE_FILE_LOCK_SLOTS` retained only as an HF-specific compatibility alias; async HF no longer reads `ASYNC_DCP_FILE_WRITE_LOCK_SLOTS`

Validation
Passed:
- `python3 -m py_compile xtuner/v1/utils/process.py xtuner/v1/model/base.py xtuner/v1/model/compose/base.py xtuner/v1/engine/train_engine.py xtuner/v1/train/trainer.py tests/train/test_trainer_async_hf.py`
- `git diff --check`

Not run locally:
- `pytest` is not installed in this environment
- `python3 -m unittest tests.train.test_trainer_async_hf -v` cannot import PyTorch internal distributed test helpers because `expecttest` is missing

Cluster validation still recommended:
- `bash sft_files/run_qwen3_moe_30ba3_ep8_hf_mem_suite.sh`
- `bash sft_files/run_qwen3_vl_moe_30ba3_ep8_hf_mem_suite.sh`
- `bash sft_files/compare_hf_non_safetensors.sh`
- `bash sft_files/check_hf_from_pretrained.sh`