fix: Diverse model seeding across PP ranks#426

Open
rrutmann wants to merge 19 commits into main from seed

Conversation

@rrutmann
Collaborator

@rrutmann rrutmann commented Dec 10, 2025

What does this PR do?

This PR gives a unique model seed for each pp rank, such that parameters are initialized differently across ranks.

General Changes

  • On each rank, add the pp rank to the model seed.
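Illustratively, the per-rank offset amounts to something like the following (a sketch only; `rank_specific_seed` is a hypothetical name, not the actual Modalities helper):

```python
def rank_specific_seed(base_seed: int, pp_rank: int) -> int:
    """Offset the configured model seed by the pipeline-parallel rank,
    so each PP stage initializes its parameters from a different stream."""
    return base_seed + pp_rank

# Each PP rank derives a distinct seed from the same configured value.
seeds = [rank_specific_seed(1234, rank) for rank in range(4)]
print(seeds)  # [1234, 1235, 1236, 1237]
```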

Breaking Changes

  • None

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

Member

@BlueCrescent BlueCrescent left a comment


Overall LGTM.
Should we also explicitly allow seeding for the "model_initialized" component?
It will probably inherit the random state from the model_raw component but it seems a bit risky to me to assume that (also in the future) no other interaction with the random state happens between these two components (though, probably, only interactions that are asymmetrical between the ranks would be problematic). In particular, since we cannot guarantee the order in which the components are built, something like a dataloader component might even re-seed the random state.

Comment thread tests/fsdp2_parallelization/test_parallel_seed_initialization.py Outdated
@rrutmann rrutmann requested a review from le1nux December 19, 2025 10:25
@rrutmann rrutmann self-assigned this Dec 19, 2025
Member

@le1nux le1nux left a comment


I checked the seeding (not the test) and, from my understanding, the changes do not provide the expected results (which is also what @BlueCrescent was hinting at).

When we seed the raw model, the model weights are indeed deterministic at instantiation time. However, we also have model weight initialization, which runs afterwards and would just overwrite the weights/seeding.

Additionally, passing device_mesh to the model is coupling two components that should normally not know anything about each other.

I think we have to integrate the seeding to the weight initializer component and can also pass in the device_mesh there.

@rrutmann
Collaborator Author

I checked the seeding (not the test) and, from my understanding, the changes do not provide the expected results (which is also what @BlueCrescent was hinting at).

When we seed the raw model, the model weights are indeed deterministic at instantiation time. However, we also have model weight initialization, which runs afterwards and would just overwrite the weights/seeding.

Additionally, passing device_mesh to the model is coupling two components that should normally not know anything about each other.

I think we have to integrate the seeding to the weight initializer component and can also pass in the device_mesh there.

Yes, that makes sense. I moved the seeding to the model initialization component.

@rrutmann
Collaborator Author

Overall LGTM. Should we also explicitly allow seeding for the "model_initialized" component? It will probably inherit the random state from the model_raw component but it seems a bit risky to me to assume that (also in the future) no other interaction with the random state happens between these two components (though, probably, only interactions that are asymmetrical between the ranks would be problematic). In particular, since we cannot guarantee the order in which the components are built, something like a dataloader component might even re-seed the random state.

See #426 (comment)

Member

@BlueCrescent BlueCrescent left a comment


LGTM

Member

@le1nux le1nux left a comment


Generally a good state.
Left a couple of comments.
My main concern is the global setting of the seed; a generator object might be preferable.

Comment thread src/modalities/models/model.py Outdated
"""NNModel class to define a base model."""

def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
def __init__(self, seed: Optional[int] = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
Member


Suggested change
def __init__(self, seed: Optional[int] = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
def __init__(self, seed: int | None = None, weight_decay_groups: Optional[WeightDecayGroups] = None):

Member


Do we even want to allow setting the seed here?
Could torch.manual_seed below have side effects with the new weight init implementation?

Collaborator Author

@rrutmann rrutmann May 5, 2026


It probably could have side effects, e.g. on default submodule initialization, random ops, and the ambient global RNG state of unrelated code. It is also mostly redundant, since we now use a local generator for weight initialization. I would suggest removing it.

Comment thread src/modalities/nn/model_initialization/composed_initialization.py Outdated
return initialization

@staticmethod
def _set_seed(seed: Optional[int]):
Member


This sets the seed globally. I think an even more robust way would be to use a local torch RNG object.
Could this be integrated?

Something like:

g = torch.Generator()
g.manual_seed(1234)
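A minimal sketch of why a local generator avoids side effects, using Python's `random.Random` as a stdlib stand-in for `torch.Generator` (the torch pattern is analogous: construct a `Generator`, call `manual_seed`, and pass it to the sampling op):

```python
import random

# Local generator: deterministic draws without touching global RNG state.
g = random.Random()
g.seed(1234)
local_draws = [g.random() for _ in range(3)]

# Re-seeding the same local generator reproduces the draws exactly...
g.seed(1234)
assert [g.random() for _ in range(3)] == local_draws

# ...and perturbing the global RNG has no effect, so unrelated code
# (e.g. a dataloader seeding the global state) cannot change weight init.
random.seed(999)
g.seed(1234)
assert [g.random() for _ in range(3)] == local_draws
```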

Member


Otherwise, we might get into side-effects later on

Collaborator Author


Makes sense. I integrated this.

if seed is not None and has_parallelism_method(
    device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
):
    seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP)
Member


This also means that, depending on the parallelization method and the number of parallelism degrees, we get differently initialized layers even if the seed is the same.
Example:

DP with seed = 1 will have a differently initialized model than DP+PP with seed = 1.

One way to fix this is to always use the same seed but have each PP stage skip the number of random values consumed by the previous stages.
However, I think this would be overkill, and I would just place a warning when initializing the weights and the parallelization method is anything other than FSDP.
What do you think?
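The "skip the previous stages' draws" idea could look roughly like this (a sketch using stdlib `random.Random` as a stand-in for `torch.Generator`; `draws_per_stage` is an assumed constant, whereas in practice the number of values consumed per stage would have to be derived from the actual parameter shapes):

```python
import random

def stage_generator(seed: int, pp_rank: int, draws_per_stage: int) -> random.Random:
    """Same seed on every rank; each PP stage fast-forwards past the draws
    consumed by earlier stages, so concatenating all stages reproduces a
    single non-parallel initialization."""
    g = random.Random(seed)
    for _ in range(pp_rank * draws_per_stage):
        g.random()  # discard values belonging to earlier stages
    return g

draws_per_stage = 3

# Reference: one generator drawing everything, as in non-parallel init.
ref = random.Random(1234)
reference = [ref.random() for _ in range(2 * draws_per_stage)]

# "Parallel" init: each stage uses the same seed but skips ahead.
stage0 = stage_generator(1234, 0, draws_per_stage)
stage1 = stage_generator(1234, 1, draws_per_stage)
parallel = [stage0.random() for _ in range(draws_per_stage)]
parallel += [stage1.random() for _ in range(draws_per_stage)]

# The concatenated per-stage draws match the single-stream reference.
assert parallel == reference
```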

Member


We could also give the hint that, for full reproducibility, a Distributed Checkpoint should be written with FSDP directly after weight init.
Maybe we could even have an entry point for that in main.

something like:
modalities create_init_cp model_config.yaml

For some unit tests, this functionality would be nice to have anyway, I think.

Any thoughts?

Collaborator Author


Good catch. I added a warning for now, but the additional entry point would be nice to have as well. I created an issue for that.

Comment thread src/modalities/nn/model_initialization/initialization_routines.py Outdated
rrutmann and others added 8 commits May 5, 2026 14:00
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <copilot@github.com>
…tialization

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <copilot@github.com>
@le1nux le1nux self-requested a review May 6, 2026 17:10
Member

@le1nux le1nux left a comment


Looks good to me! Nice work :)
The comment regarding the per-device Generator we should discuss, what makes most sense here.

I would add one last test which checks that two models instantiated with the same config file (with a specified seed) have 100% matching parameter weights. I'd keep that one simple (no advanced sharding like TP or PP, only FSDP).

from transformers.utils.generic import check_model_inputs
except ImportError:

def check_model_inputs(func: Callable) -> Callable:
Member


Was this removed in transformers?
If it is part of a legacy API, I think we should also remove this on our end.
What do you think @BlueCrescent? I think you added it, right?

self.seed = torch.initial_seed() if seed is None else seed
self._generators: dict[str, torch.Generator] = {}

def _get_generator(self, parameter: torch.Tensor) -> torch.Generator:
Member


A few things are not clear to me.

  1. Do we actually have the case where, within a single process, tensors are sitting on different GPUs?
  2. If 1. is the case, then we can end up with tensors that are initialized identically, since we create multiple generators from the same seed.

I'm not sure what the best way to solve this is... it also seems to me that the PyTorch API regarding Generators is kinda limited.
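The concern in point 2 can be illustrated with a stdlib stand-in (`random.Random` for `torch.Generator`, device strings as cache keys; `InitSketch` is a hypothetical sketch, not the actual Modalities code): caching one generator per device but seeding every cache entry identically makes the draws on each device identical.

```python
import random

class InitSketch:
    """Caches one generator per device key, each seeded with the same value."""

    def __init__(self, seed: int):
        self.seed = seed
        self._generators: dict[str, random.Random] = {}

    def _get_generator(self, device: str) -> random.Random:
        # Lazily create a generator per device, all from the same seed.
        if device not in self._generators:
            self._generators[device] = random.Random(self.seed)
        return self._generators[device]

init = InitSketch(1234)
a = [init._get_generator("cuda:0").random() for _ in range(3)]
b = [init._get_generator("cuda:1").random() for _ in range(3)]

# Identical streams: parameters on different devices get identical values.
assert a == b
```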

std (float): standard deviation of the normal distribution. If set to "auto", an appropriate
value is selected as per the plain initialization described in https://arxiv.org/abs/2312.16903
hidden_dim (Optional[int]): hidden dimension of the attention layer. Defaults to None.
parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
Member


I think the type annotation is wrong. It should be RegexFilter.

@@ -99,6 +118,7 @@ def get_scaled_initialization(
num_layers (int): Number of layers in the model which we use to downscale std with
parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
Member


Suggested change
parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
parameter_name_regexes (RegexFilter): List of parameter name regexes to which the initialization
