diff --git a/modeling_utils_fixed.py b/modeling_utils_fixed.py new file mode 100644 index 000000000000..90cef4587b96 --- /dev/null +++ b/modeling_utils_fixed.py @@ -0,0 +1,21 @@ + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + +from contextlib import ExitStack, contextmanager +from functools import wraps +from pathlib import Path +from typing import Any, Callable, ContextManager, List, Tuple, Type + +import safetensors +import torch diff --git a/modeling_utils_original.clean.py b/modeling_utils_original.clean.py new file mode 100644 index 000000000000..90cef4587b96 --- /dev/null +++ b/modeling_utils_original.clean.py @@ -0,0 +1,21 @@ + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + +from contextlib import ExitStack, contextmanager +from functools import wraps +from pathlib import Path +from typing import Any, Callable, ContextManager, List, Tuple, Type + +import safetensors +import torch