diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 1eca954e53..e306aef570 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2886,6 +2886,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.") if self.elastic_enabled and not self.enable_single_controller: raise ValueError("Elastic training is only supported with Pathways (`enable_single_controller=True`).") + if self.colocated_python_data_input and not self.enable_single_controller: + raise ValueError( + "Colocated python data input is only supported with Pathways (single" + " controller) enabled (`enable_single_controller=True`)." + ) if self.grain_use_elastic_iterator and self.grain_file_type != "arrayrecord": raise ValueError( "`grain_use_elastic_iterator=True` only supports `grain_file_type=arrayrecord`. " diff --git a/src/maxtext/input_pipeline/multihost_dataloading.py b/src/maxtext/input_pipeline/multihost_dataloading.py index 221a6ed338..9ae3d9ca34 100644 --- a/src/maxtext/input_pipeline/multihost_dataloading.py +++ b/src/maxtext/input_pipeline/multihost_dataloading.py @@ -264,8 +264,7 @@ class RemoteIteratorWrapper: """Wrapper for RemoteIterator that handles device placement.""" def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape, checkpoint_path="", elastic=False): - self.cpu_devices = _colocated_cpu_devices(jax.local_devices()) - self.tpu_devices = jax.local_devices() + self.cpu_devices = _colocated_cpu_devices(tuple(global_mesh.devices.flat)) self.cpu_mesh = _colocated_cpu_mesh(global_mesh) self.tpu_sharding = jax.sharding.NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names)) self.cpu_sharding = jax.sharding.NamedSharding(self.cpu_mesh, PartitionSpec(self.cpu_mesh.axis_names)) @@ -288,11 +287,13 @@ def __next__(self): return jax.device_put(out, self.tpu_sharding) def save_state(self, step): - step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32) - step_array = jax.device_put(step_array, self.cpu_sharding) + replicated_cpu_sharding = NamedSharding(self.cpu_mesh, PartitionSpec()) + step_array = jnp.array(step, dtype=jnp.int32) + step_array = jax.device_put(step_array, replicated_cpu_sharding) self.local_iterator.save_state(step_array) def restore_state(self, step): - step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32) - step_array = jax.device_put(step_array, self.cpu_sharding) + replicated_cpu_sharding = NamedSharding(self.cpu_mesh, PartitionSpec()) + step_array = jnp.array(step, dtype=jnp.int32) + step_array = jax.device_put(step_array, replicated_cpu_sharding) self.local_iterator.restore_state(step_array) diff --git a/tests/unit/multihost_dataloading_test.py b/tests/unit/multihost_dataloading_test.py index d4a6172141..2324b6c8c9 100644 --- a/tests/unit/multihost_dataloading_test.py +++ b/tests/unit/multihost_dataloading_test.py @@ -1,4 +1,4 @@ -# Copyright 2023–2025 Google LLC +# Copyright 2023–2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,29 +12,89 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=missing-module-docstring, missing-function-docstring +# pylint: disable=missing-module-docstring, missing-function-docstring, line-too-long, g-generic-assert import itertools +import json + +import pathlib import sys -import unittest +import tempfile -import pytest -import numpy as np +from unittest import mock +from absl.testing import absltest +from absl.testing import parameterized import jax -from jax.sharding import Mesh from jax.experimental import mesh_utils - +import jax.experimental.colocated_python +from jax.sharding import Mesh from maxtext.configs import pyconfig +from tests.utils.test_helpers import get_test_base_output_directory +from tests.utils.test_helpers import get_test_config_path +from tests.utils.test_helpers import get_test_dataset_path +import numpy as np +import pytest + +# Mock jax.experimental.colocated_python before it is imported by +# multihost_dataloading +mock.patch.object( + jax.experimental.colocated_python, + "colocated_python_class", + lambda cls: cls, +).start() +mock.patch.object( + jax.experimental.colocated_python, + "colocated_cpu_devices", + lambda x: x, +).start() + +# pylint: disable=g-import-not-at-top, g-bad-import-order from maxtext.input_pipeline import multihost_dataloading -from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory +# pylint: enable=g-import-not-at-top, g-bad-import-order + + +class MockIterator: + """Mock iterator for testing dataloading state saving/restoring.""" + + def __init__(self, mesh_size): + self.state = 0 + self.mesh_size = mesh_size + + def __next__(self): + self.state += 1 + return np.full((self.mesh_size, 1), self.state, dtype=np.int32) + + def get_state(self) -> dict[str, int]: + return {"state": self.state} + + def set_state(self, state: dict[str, int]): + self.state = state["state"] + + +class MockDataloader: + """Mock dataloader for testing.""" + + def __init__(self, mesh_size): + self.mesh_size = mesh_size + + def __iter__(self) -> MockIterator: + return MockIterator(self.mesh_size) -class MultihostDataloadingTest(unittest.TestCase): +def _get_test_mesh_shapes_named(): + return [ + ("1_device", (1, 1)), + ("2_devices", (2, 1)), + ("4_devices", (2, 2)), + ] + + +class MultihostDataloadingTest(parameterized.TestCase): def setUp(self): super().setUp() - # Note: this test uses gs://max-experiments/ (not gs://runner-maxtext-logs) in cloud mode + # Note: this test uses gs://max-experiments/ (not runner logs) in cloud mode base_output_directory = get_test_base_output_directory(cloud_path="gs://max-experiments/") dataset_path = get_test_dataset_path(cloud_path="gs://maxtext-dataset/") batch_size = len(jax.devices()) @@ -62,8 +122,94 @@ def setUp(self): def test_batch_sharded_data_pipeline(self): first_batch = next(self.multihost_gen) sec_batch = next(self.multihost_gen) - self.assertTrue(not np.array_equal(first_batch, sec_batch, equal_nan=True)) + self.assertFalse(np.array_equal(first_batch, sec_batch, equal_nan=True)) + + @parameterized.named_parameters(*_get_test_mesh_shapes_named()) + def test_remote_iterator_wrapper_save_state(self, mesh_shape): + mesh_size = mesh_shape[0] * mesh_shape[1] + if mesh_size > len(jax.devices()): + self.skipTest( + f"Skipping test because available devices ({len(jax.devices())}) is" + f" less than required mesh size ({mesh_size}) for shape {mesh_shape}." + ) + + devs = jax.devices()[:mesh_size] + devices = mesh_utils.create_device_mesh(mesh_shape, devs) + mesh = Mesh(devices, ("x", "y")) + + def get_ds_fn(dataloading_host_index, dataloading_host_count): + del dataloading_host_index, dataloading_host_count + return MockDataloader(mesh_size) + + def preprocessing_fn(dataset): + return dataset + + global_shape = (mesh_size, 1) + + with tempfile.TemporaryDirectory() as tmpdir: + wrapper = multihost_dataloading.RemoteIteratorWrapper( + get_ds_fn=get_ds_fn, + preprocessing_fn=preprocessing_fn, + global_mesh=mesh, + global_shape=global_shape, + checkpoint_path=tmpdir, + elastic=False, + ) + # Advance state once so the value is 1 + next(wrapper) + + wrapper.save_state(step=5) + + # Verify that a file was written in the tempdir containing {"state": 1} + json_files = list(pathlib.Path(tmpdir).glob("**/*.json")) + self.assertEqual(len(json_files), 1, f"Expected 1 JSON file, found: {json_files}") + written_content = json_files[0].read_text() + self.assertEqual(json.loads(written_content), {"state": 1}) + + @parameterized.named_parameters(*_get_test_mesh_shapes_named()) + def test_remote_iterator_wrapper_restore_state(self, mesh_shape): + mesh_size = mesh_shape[0] * mesh_shape[1] + if mesh_size > len(jax.devices()): + self.skipTest( + f"Skipping test because available devices ({len(jax.devices())}) is" + f" less than required mesh size ({mesh_size}) for shape {mesh_shape}." + ) + + devs = jax.devices()[:mesh_size] + devices = mesh_utils.create_device_mesh(mesh_shape, devs) + mesh = Mesh(devices, ("x", "y")) + + def get_ds_fn(dataloading_host_index, dataloading_host_count): + del dataloading_host_index, dataloading_host_count + return MockDataloader(mesh_size) + + def preprocessing_fn(dataset): + return dataset + + global_shape = (mesh_size, 1) + + with tempfile.TemporaryDirectory() as tmpdir: + step = 5 + state_dir = pathlib.Path(tmpdir) / str(step) / "iter" + state_dir.mkdir(parents=True, exist_ok=True) + state_file = state_dir / "process_0-of-1.json" + state_file.write_text('{"state": 10}') + + wrapper = multihost_dataloading.RemoteIteratorWrapper( + get_ds_fn=get_ds_fn, + preprocessing_fn=preprocessing_fn, + global_mesh=mesh, + global_shape=global_shape, + checkpoint_path=tmpdir, + elastic=False, + ) + + wrapper.restore_state(step=5) + val = next(wrapper) + + # Next value should be 11 (state 10 + 1) + self.assertEqual(val.addressable_data(0)[0], 11) if __name__ == "__main__": - unittest.main() + absltest.main()