From e14301ba9a9100e8c16fc07a8c8c9368fcfc5d97 Mon Sep 17 00:00:00 2001 From: whning Date: Fri, 19 Jun 2026 14:36:24 +0800 Subject: [PATCH] =?UTF-8?q?Fix=20#13789:=20UNet2DModel=20dtype=20=E5=B1=9E?= =?UTF-8?q?=E6=80=A7=E5=9C=A8=20nn.DataParallel=20=E4=B8=8B=E6=8A=A5=20Unb?= =?UTF-8?q?oundLocalError?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- modeling_utils_fixed.py | 21 +++++++++++++++++++++ modeling_utils_original.clean.py | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 modeling_utils_fixed.py create mode 100644 modeling_utils_original.clean.py diff --git a/modeling_utils_fixed.py b/modeling_utils_fixed.py new file mode 100644 index 000000000000..90cef4587b96 --- /dev/null +++ b/modeling_utils_fixed.py @@ -0,0 +1,21 @@ + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + +from contextlib import ExitStack, contextmanager +from functools import wraps +from pathlib import Path +from typing import Any, Callable, ContextManager, List, Tuple, Type + +import safetensors +import torch diff --git a/modeling_utils_original.clean.py b/modeling_utils_original.clean.py new file mode 100644 index 000000000000..90cef4587b96 --- /dev/null +++ b/modeling_utils_original.clean.py @@ -0,0 +1,21 @@ + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + +from contextlib import ExitStack, contextmanager +from functools import wraps +from pathlib import Path +from typing import Any, Callable, ContextManager, List, Tuple, Type + +import safetensors +import torch