Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions monai/networks/blocks/transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,20 @@ def __init__(
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
causal (bool, optional): whether to apply causal masking in self-attention. Defaults to False.
sequence_length (int | None, optional): sequence length required for causal masking. Defaults to None.
with_cross_attention (bool, optional): whether to include cross-attention layers that attend to an
external context tensor. When False, cross_attn is set to nn.Identity() so that the attribute
always exists for typing and checkpoint compatibility. Defaults to False.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to True.

Raises:
ValueError: if dropout_rate is not in [0, 1].
ValueError: if hidden_size is not divisible by num_heads.

"""

super().__init__()
Expand Down Expand Up @@ -79,14 +88,18 @@ def __init__(
self.with_cross_attention = with_cross_attention

self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=False,
use_flash_attention=use_flash_attention,
)
self.cross_attn: CrossAttentionBlock | nn.Identity
if with_cross_attention:
Comment thread
ericspod marked this conversation as resolved.
self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=False,
use_flash_attention=use_flash_attention,
)
else:
self.cross_attn = nn.Identity()

def forward(
self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None
Expand Down
38 changes: 38 additions & 0 deletions monai/networks/nets/masked_autoencoder_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,41 @@ def forward(self, x, masking_ratio: float | None = None):

x = x[:, 1:, :]
return x, mask

def load_old_state_dict(self, old_state_dict: dict, verbose: bool = False) -> None:
    """
    Load a state dict from a MaskedAutoEncoderViT model trained with an older version of MONAI
    where ``CrossAttentionBlock`` was unconditionally instantiated in ``TransformerBlock``
    even when ``with_cross_attention=False``. Old checkpoints contain stale
    ``blocks.{i}.cross_attn.*`` and ``decoder_blocks.{i}.cross_attn.*`` keys that are not
    present in the current model and are automatically dropped.

    If every key of ``old_state_dict`` exists in the current model, the dict is passed to
    ``load_state_dict`` unchanged. Otherwise the current state dict is used as a template:
    matching keys are overwritten with the old values, keys that only exist in the current
    model keep their present values, and keys that only exist in ``old_state_dict`` are ignored.

    The ``old_state_dict`` passed in is never modified.

    Args:
        old_state_dict: state dict from the older MaskedAutoEncoderViT model.
        verbose: if True, print keys that are missing or unmatched. Defaults to False.
    """
    new_state_dict = self.state_dict()
    if all(k in new_state_dict for k in old_state_dict):
        if verbose:
            print("All keys match, loading state dict.")
        self.load_state_dict(old_state_dict)
        return

    if verbose:
        for k in new_state_dict:
            if k not in old_state_dict:
                print(f"key {k} not found in old state dict")
        print("----------------------------------------------")
        for k in old_state_dict:
            if k not in new_state_dict:
                print(f"key {k} not found in new state dict")

    # Work on a shallow copy so the caller's dict keeps all of its entries; only the
    # copy is consumed while matching keys are transferred.
    remaining = old_state_dict.copy()

    # copy over all matching keys; stale cross_attn.* keys in blocks and decoder_blocks
    # are left as unmatched leftovers and are not inserted into new_state_dict
    for k in new_state_dict:
        if k in remaining:
            new_state_dict[k] = remaining.pop(k)

    if verbose:
        print("remaining keys in old_state_dict:", remaining.keys())
    self.load_state_dict(new_state_dict)
38 changes: 38 additions & 0 deletions monai/networks/nets/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,41 @@ def forward(self, x):
if hasattr(self, "classification_head"):
x = self.classification_head(x[:, 0])
return x, hidden_states_out

def load_old_state_dict(self, old_state_dict: dict, verbose: bool = False) -> None:
    """
    Load a state dict from a ViT model trained with an older version of MONAI where
    ``CrossAttentionBlock`` was unconditionally instantiated in ``TransformerBlock``
    even when ``with_cross_attention=False``. Old checkpoints contain stale
    ``blocks.{i}.cross_attn.*`` keys that are not present in the current model and
    are automatically dropped.

    If every key of ``old_state_dict`` exists in the current model, the dict is passed to
    ``load_state_dict`` unchanged. Otherwise the current state dict is used as a template:
    matching keys are overwritten with the old values, keys that only exist in the current
    model keep their present values, and keys that only exist in ``old_state_dict`` are ignored.

    The ``old_state_dict`` passed in is never modified.

    Args:
        old_state_dict: state dict from the older ViT model.
        verbose: if True, print keys that are missing or unmatched. Defaults to False.
    """
    new_state_dict = self.state_dict()
    if all(k in new_state_dict for k in old_state_dict):
        if verbose:
            print("All keys match, loading state dict.")
        self.load_state_dict(old_state_dict)
        return

    if verbose:
        for k in new_state_dict:
            if k not in old_state_dict:
                print(f"key {k} not found in old state dict")
        print("----------------------------------------------")
        for k in old_state_dict:
            if k not in new_state_dict:
                print(f"key {k} not found in new state dict")

    # Work on a shallow copy so the caller's dict keeps all of its entries; only the
    # copy is consumed while matching keys are transferred.
    remaining = old_state_dict.copy()

    # copy over all matching keys; stale cross_attn.* keys are left as unmatched
    # leftovers and are not inserted into new_state_dict
    for k in new_state_dict:
        if k in remaining:
            new_state_dict[k] = remaining.pop(k)

    if verbose:
        print("remaining keys in old_state_dict:", remaining.keys())
    self.load_state_dict(new_state_dict)
38 changes: 38 additions & 0 deletions monai/networks/nets/vitautoenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,41 @@ def forward(self, x):
x = self.conv3d_transpose(x)
x = self.conv3d_transpose_1(x)
return x, hidden_states_out

def load_old_state_dict(self, old_state_dict: dict, verbose: bool = False) -> None:
    """
    Load a state dict from a ViTAutoEnc model trained with an older version of MONAI where
    ``CrossAttentionBlock`` was unconditionally instantiated in ``TransformerBlock``
    even when ``with_cross_attention=False``. Old checkpoints contain stale
    ``blocks.{i}.cross_attn.*`` keys that are not present in the current model and
    are automatically dropped.

    If every key of ``old_state_dict`` exists in the current model, the dict is passed to
    ``load_state_dict`` unchanged. Otherwise the current state dict is used as a template:
    matching keys are overwritten with the old values, keys that only exist in the current
    model keep their present values, and keys that only exist in ``old_state_dict`` are ignored.

    The ``old_state_dict`` passed in is never modified.

    Args:
        old_state_dict: state dict from the older ViTAutoEnc model.
        verbose: if True, print keys that are missing or unmatched. Defaults to False.
    """
    new_state_dict = self.state_dict()
    if all(k in new_state_dict for k in old_state_dict):
        if verbose:
            print("All keys match, loading state dict.")
        self.load_state_dict(old_state_dict)
        return

    if verbose:
        for k in new_state_dict:
            if k not in old_state_dict:
                print(f"key {k} not found in old state dict")
        print("----------------------------------------------")
        for k in old_state_dict:
            if k not in new_state_dict:
                print(f"key {k} not found in new state dict")

    # Work on a shallow copy so the caller's dict keeps all of its entries; only the
    # copy is consumed while matching keys are transferred.
    remaining = old_state_dict.copy()

    # copy over all matching keys; stale cross_attn.* keys are left as unmatched
    # leftovers and are not inserted into new_state_dict
    for k in new_state_dict:
        if k in remaining:
            new_state_dict[k] = remaining.pop(k)

    if verbose:
        print("remaining keys in old_state_dict:", remaining.keys())
    self.load_state_dict(new_state_dict)
32 changes: 32 additions & 0 deletions tests/networks/blocks/test_transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

import numpy as np
import torch
import torch.nn as nn
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.crossattention import CrossAttentionBlock
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.utils import optional_import
from tests.test_utils import dict_product
Expand Down Expand Up @@ -53,6 +55,36 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)

@skipUnless(has_einops, "Requires einops")
def test_cross_attention_is_identity_when_disabled(self):
    """With ``with_cross_attention=False`` the block exposes a parameter-free ``nn.Identity``."""
    blk = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False)
    # both attributes must be present regardless of the flag (typing / checkpoint compatibility)
    for attr_name in ("cross_attn", "norm_cross_attn"):
        self.assertTrue(hasattr(blk, attr_name))
    # the disabled branch is an nn.Identity and must not register any parameters
    self.assertIsInstance(blk.cross_attn, nn.Identity)
    cross_attn_params = [name for name, _ in blk.named_parameters() if name.startswith("cross_attn")]
    self.assertFalse(cross_attn_params)

@skipUnless(has_einops, "Requires einops")
def test_cross_attention_params_registered_when_enabled(self):
    """With ``with_cross_attention=True`` the cross-attention block and its norm own parameters."""
    blk = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
    self.assertIsInstance(blk.cross_attn, CrossAttentionBlock)
    self.assertTrue(hasattr(blk, "norm_cross_attn"))
    names = [entry[0] for entry in blk.named_parameters()]
    # parameters of the cross-attention module itself
    has_cross_attn_params = any(name.startswith("cross_attn.") for name in names)
    self.assertTrue(has_cross_attn_params)
    # parameters of the layer norm that precedes it
    has_norm_params = any("norm_cross_attn" in name for name in names)
    self.assertTrue(has_norm_params)

@skipUnless(has_einops, "Requires einops")
def test_cross_attention_forward_with_context(self):
    """A forward pass with an external context tensor returns a tensor shaped like the input."""
    dim = 128
    blk = TransformerBlock(hidden_size=dim, mlp_dim=256, num_heads=4, with_cross_attention=True)
    inputs = torch.randn(2, 16, dim)
    ctx = torch.randn(2, 8, dim)
    with eval_mode(blk):
        result = blk(inputs, context=ctx)
    self.assertEqual(result.shape, inputs.shape)

@skipUnless(has_einops, "Requires einops")
def test_access_attn_matrix(self):
# input format
Expand Down
32 changes: 32 additions & 0 deletions tests/networks/nets/test_masked_autoencoder_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,38 @@ def test_masking_ratio(self):

assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens

def test_load_old_state_dict_drops_stale_cross_attn_keys(self):
    """Stale encoder/decoder ``cross_attn`` keys in an old checkpoint are ignored on load."""
    # simulate an old checkpoint where CrossAttentionBlock was always instantiated
    net = MaskedAutoEncoderViT(
        in_channels=1,
        img_size=(32, 32),
        patch_size=(16, 16),
        hidden_size=64,
        mlp_dim=128,
        num_layers=2,
        num_heads=4,
        decoder_hidden_size=32,
        decoder_mlp_dim=64,
        decoder_num_layers=2,
        decoder_num_heads=4,
        spatial_dims=2,
    )
    old_state = {}
    for name, tensor in net.state_dict().items():
        old_state[name] = torch.rand_like(tensor)
    # inject stale cross_attn keys from both encoder blocks and decoder blocks
    old_state.update(
        {
            "blocks.0.cross_attn.to_q.weight": torch.randn(64, 64),
            "blocks.1.cross_attn.out_proj.weight": torch.randn(64, 64),
            "decoder_blocks.0.cross_attn.to_v.weight": torch.randn(32, 32),
            "decoder_blocks.1.cross_attn.out_proj.weight": torch.randn(32, 32),
        }
    )

    # snapshot the values the model is expected to end up with before the call,
    # since load_old_state_dict may consume entries of the dict it is given
    current = net.state_dict()
    expected = {name: tensor.clone() for name, tensor in old_state.items() if name in current}
    net.load_old_state_dict(old_state)

    # all current model keys should be loaded from old_state; stale keys silently dropped
    for name, tensor in net.state_dict().items():
        self.assertTrue(torch.allclose(tensor, expected[name]))


if __name__ == "__main__":
unittest.main()
27 changes: 27 additions & 0 deletions tests/networks/nets/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,33 @@ def test_access_attn_matrix(self):
matrix_acess_blk(torch.randn(in_shape))
assert matrix_acess_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 216, 216)

def test_load_old_state_dict_drops_stale_cross_attn_keys(self):
    """Stale ``blocks.{i}.cross_attn.*`` keys in an old ViT checkpoint are ignored on load."""
    # simulate an old checkpoint where CrossAttentionBlock was always instantiated
    net = ViT(
        in_channels=1,
        img_size=(32, 32),
        patch_size=(16, 16),
        hidden_size=64,
        mlp_dim=128,
        num_layers=2,
        num_heads=4,
        spatial_dims=2,
    )
    old_state = {}
    for name, tensor in net.state_dict().items():
        old_state[name] = torch.rand_like(tensor)
    # inject stale cross_attn keys that the new model no longer has
    old_state.update(
        {
            "blocks.0.cross_attn.to_q.weight": torch.randn(64, 64),
            "blocks.0.cross_attn.out_proj.weight": torch.randn(64, 64),
            "blocks.1.cross_attn.to_v.weight": torch.randn(64, 64),
        }
    )

    # snapshot the values the model is expected to end up with before the call,
    # since load_old_state_dict may consume entries of the dict it is given
    current = net.state_dict()
    expected = {name: tensor.clone() for name, tensor in old_state.items() if name in current}
    net.load_old_state_dict(old_state)

    # all current model keys should be loaded from old_state; stale keys silently dropped
    for name, tensor in net.state_dict().items():
        self.assertTrue(torch.allclose(tensor, expected[name]))


if __name__ == "__main__":
unittest.main()
27 changes: 27 additions & 0 deletions tests/networks/nets/test_vitautoenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,33 @@ def test_ill_arg(
dropout_rate=dropout_rate,
)

def test_load_old_state_dict_drops_stale_cross_attn_keys(self):
    """Stale ``blocks.{i}.cross_attn.*`` keys in an old ViTAutoEnc checkpoint are ignored on load."""
    # simulate an old checkpoint where CrossAttentionBlock was always instantiated
    net = ViTAutoEnc(
        in_channels=1,
        img_size=(32, 32),
        patch_size=(16, 16),
        hidden_size=64,
        mlp_dim=128,
        num_layers=2,
        num_heads=4,
        spatial_dims=2,
    )
    old_state = {}
    for name, tensor in net.state_dict().items():
        old_state[name] = torch.rand_like(tensor)
    # inject stale cross_attn keys that the new model no longer has
    old_state.update(
        {
            "blocks.0.cross_attn.to_q.weight": torch.randn(64, 64),
            "blocks.0.cross_attn.out_proj.weight": torch.randn(64, 64),
            "blocks.1.cross_attn.to_v.weight": torch.randn(64, 64),
        }
    )

    # snapshot the values the model is expected to end up with before the call,
    # since load_old_state_dict may consume entries of the dict it is given
    current = net.state_dict()
    expected = {name: tensor.clone() for name, tensor in old_state.items() if name in current}
    net.load_old_state_dict(old_state)

    # all current model keys should be loaded from old_state; stale keys silently dropped
    for name, tensor in net.state_dict().items():
        self.assertTrue(torch.allclose(tensor, expected[name]))


if __name__ == "__main__":
unittest.main()
Loading