diff --git a/b/src/diffusers/pipelines/pipeline_loading_utils.py b/b/src/diffusers/pipelines/pipeline_loading_utils.py new file mode 100644 index 000000000000..dc9a2169b977 --- /dev/null +++ b/b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -0,0 +1,11 @@ + + +def _fetch_class_library_tuple(module): + # Guard against non-module values (bool, int, str, etc.) that can be passed + # when a positional argument shift occurs in subclasses (e.g. issue #6969). + if not isinstance(module, (torch.nn.Module, type)): + return (None, None) + + # import it here to avoid circular import + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index d695f5e7284d..5eb61ddeb460 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -952,6 +952,11 @@ def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> def _fetch_class_library_tuple(module): + # Guard against non-module values (bool, int, str, etc.) that can be passed + # when a positional argument shift occurs in subclasses (e.g. issue #6969). + if not isinstance(module, (torch.nn.Module, type)): + return (None, None) + # import it here to avoid circular import diffusers_module = importlib.import_module(__name__.split(".")[0]) pipelines = getattr(diffusers_module, "pipelines")