From 53229eb82b40ae58b9289c8fc29bb112a4c9b21a Mon Sep 17 00:00:00 2001 From: whning Date: Fri, 19 Jun 2026 14:36:47 +0800 Subject: [PATCH] =?UTF-8?q?Fix=20#6969:=20`=5Ffetch=5Fclass=5Flibrary=5Ftu?= =?UTF-8?q?ple`=20=E5=AF=B9=E9=9D=9E=E6=A8=A1=E5=9D=97=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E7=9A=84=E9=98=B2=E5=BE=A1=E6=80=A7=E6=A3=80?= =?UTF-8?q?=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- b/src/diffusers/pipelines/pipeline_loading_utils.py | 11 +++++++++++ src/diffusers/pipelines/pipeline_loading_utils.py | 5 +++++ 2 files changed, 16 insertions(+) create mode 100644 b/src/diffusers/pipelines/pipeline_loading_utils.py diff --git a/b/src/diffusers/pipelines/pipeline_loading_utils.py b/b/src/diffusers/pipelines/pipeline_loading_utils.py new file mode 100644 index 000000000000..dc9a2169b977 --- /dev/null +++ b/b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -0,0 +1,11 @@ + + +def _fetch_class_library_tuple(module): + # Guard against non-module values (bool, int, str, etc.) that can be passed + # when a positional argument shift occurs in subclasses (e.g. issue #6969). + if not isinstance(module, (torch.nn.Module, type)): + return (None, None) + + # import it here to avoid circular import + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index d695f5e7284d..5eb61ddeb460 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -952,6 +952,11 @@ def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> def _fetch_class_library_tuple(module): + # Guard against non-module values (bool, int, str, etc.) that can be passed + # when a positional argument shift occurs in subclasses (e.g. issue #6969). + if not isinstance(module, (torch.nn.Module, type)): + return (None, None) + # import it here to avoid circular import diffusers_module = importlib.import_module(__name__.split(".")[0]) pipelines = getattr(diffusers_module, "pipelines")