diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 2bb0e17de..91348e588 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -13,7 +13,7 @@ EulerEDMSampler, HeunEDMSampler, LinearMultistepSampler) -from sgm.util import load_model_from_config +from sgm.util import load_model_from_config, get_default_device_name class ModelArchitecture(str, Enum): @@ -136,7 +136,7 @@ def __init__( model_id: ModelArchitecture, model_path="checkpoints", config_path="configs/inference", - device="cuda", + device: Optional[str] = None, use_fp16=True, ) -> None: if model_id not in model_specs: @@ -145,10 +145,10 @@ def __init__( self.specs = model_specs[self.model_id] self.config = str(pathlib.Path(config_path, self.specs.config)) self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) - self.device = device + self.device = device or get_default_device_name() self.model = self._load_model(device=device, use_fp16=use_fp16) - def _load_model(self, device="cuda", use_fp16=True): + def _load_model(self, *, device, use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) if model is None: diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 31b0ec3dc..cb111eb47 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -4,13 +4,12 @@ import numpy as np import torch +from PIL import Image from einops import rearrange from imwatermark import WatermarkEncoder from omegaconf import ListConfig -from PIL import Image -from torch import autocast -from sgm.util import append_dims +from sgm.util import append_dims, safe_autocast, get_default_device_name class WatermarkEmbedder: @@ -111,21 +110,24 @@ def do_sample( batch2model_input: Optional[List] = None, return_latents=False, filter=None, - device="cuda", + device: Optional[str] = None, ): + if not device: + device = get_default_device_name() if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] if batch2model_input is None: batch2model_input = [] with torch.no_grad(): - with autocast(device) as precision_scope: + with safe_autocast(device): with model.ema_scope(): num_samples = [num_samples] batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples, + device=device, ) for key in batch: if isinstance(batch[key], torch.Tensor): @@ -170,7 +172,13 @@ def denoiser(input, sigma, c): return samples -def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): +def get_batch( + keys, + value_dict, + N: Union[List, ListConfig], + *, + device: str, +): # Hardcoded demo setups; might undergo some changes in the future batch = {} @@ -227,7 +235,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): return batch, batch_uc -def get_input_image_tensor(image: Image.Image, device="cuda"): +def get_input_image_tensor(image: Image.Image, device: Optional[str] = None): + if not device: + device = get_default_device_name() w, h = image.size print(f"loaded input image of size ({w}, {h})") width, height = map( @@ -252,15 +262,18 @@ def do_img2img( return_latents=False, skip_encode=False, filter=None, - device="cuda", + device: Optional[str] = None, ): + if not device: + device = get_default_device_name() with torch.no_grad(): - with autocast(device) as precision_scope: + with safe_autocast(device): with model.ema_scope(): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [num_samples], + device=device, ) c, uc = model.conditioner.get_unconditional_conditioning( batch, diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py index 2f3efd3c7..9a9ca61bd 100644 --- a/sgm/models/diffusion.py +++ b/sgm/models/diffusion.py @@ -1,6 +1,6 @@ import math from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Optional import pytorch_lightning as pl import torch @@ -12,8 +12,15 @@ from ..modules.autoencoding.temporal_ae import VideoDecoder from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER from ..modules.ema import LitEma -from ..util import (default, disabled_train, get_obj_from_str, - instantiate_from_config, log_txt_as_img) +from ..util import ( + default, + disabled_train, + get_default_device_name, + get_obj_from_str, + instantiate_from_config, + log_txt_as_img, + safe_autocast, +) class DiffusionEngine(pl.LightningModule): @@ -114,6 +121,12 @@ def get_input(self, batch): # image tensors should be scaled to -1 ... 1 and in bchw format return batch[self.input_key] + def _first_stage_autocast_context(self): + return safe_autocast( + device=get_default_device_name(), + enabled=not self.disable_first_stage_autocast, + ) + @torch.no_grad() def decode_first_stage(self, z): z = 1.0 / self.scale_factor * z @@ -121,7 +134,7 @@ def decode_first_stage(self, z): n_rounds = math.ceil(z.shape[0] / n_samples) all_out = [] - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + with self._first_stage_autocast_context(): for n in range(n_rounds): if isinstance(self.first_stage_model.decoder, VideoDecoder): kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} @@ -139,7 +152,7 @@ def encode_first_stage(self, x): n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) n_rounds = math.ceil(x.shape[0] / n_samples) all_out = [] - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + with self._first_stage_autocast_context(): for n in range(n_rounds): out = self.first_stage_model.encode( x[n * n_samples : (n + 1) * n_samples] diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py index e3949dd1e..00f16131a 100644 --- a/sgm/modules/diffusionmodules/openaimodel.py +++ b/sgm/modules/diffusionmodules/openaimodel.py @@ -10,9 +10,15 @@ from torch.utils.checkpoint import checkpoint from ...modules.attention import SpatialTransformer -from ...modules.diffusionmodules.util import (avg_pool_nd, conv_nd, linear, - normalization, - timestep_embedding, zero_module) +from ...modules.diffusionmodules.util import ( + avg_pool_nd, + checkpoint, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) from ...modules.video_attention import SpatialVideoTransformer from ...util import exists diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py index af07566d5..93a251541 100644 --- a/sgm/modules/diffusionmodules/sampling.py +++ b/sgm/modules/diffusionmodules/sampling.py @@ -9,11 +9,14 @@ from omegaconf import ListConfig, OmegaConf from tqdm import tqdm -from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step, - linear_multistep_coeff, - to_d, to_neg_log_sigma, - to_sigma) -from ...util import append_dims, default, instantiate_from_config +from ...modules.diffusionmodules.sampling_utils import ( + get_ancestral_step, + linear_multistep_coeff, + to_d, + to_neg_log_sigma, + to_sigma, +) +from ...util import append_dims, default, instantiate_from_config, get_default_device_name DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} @@ -25,8 +28,10 @@ def __init__( num_steps: Union[int, None] = None, guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, verbose: bool = False, - device: str = "cuda", + device: Union[str, None] = None, ): + if device is None: + device = get_default_device_name() self.num_steps = num_steps self.discretization = instantiate_from_config(discretization_config) self.guider = instantiate_from_config( diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py index 48bd5ea8b..e99db5af6 100644 --- a/sgm/modules/encoders/modules.py +++ b/sgm/modules/encoders/modules.py @@ -20,8 +20,17 @@ from ...modules.diffusionmodules.util import (extract_into_tensor, make_beta_schedule) from ...modules.distributions.distributions import DiagonalGaussianDistribution -from ...util import (append_dims, autocast, count_params, default, - disabled_train, expand_dims_like, instantiate_from_config) +from ...util import ( + append_dims, + autocast, + count_params, + default, + disabled_train, + expand_dims_like, + get_default_device_name, + instantiate_from_config, + safe_autocast, +) class AbstractEmbModel(nn.Module): @@ -229,7 +238,9 @@ def forward(self, c): c = c[:, None, :] return c - def get_unconditional_conditioning(self, bs, device="cuda"): + def get_unconditional_conditioning(self, bs, device=None): + if device is None: + device = get_default_device_name() uc_class = ( self.n_classes - 1 ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) @@ -254,9 +265,10 @@ class FrozenT5Embedder(AbstractEmbModel): """Uses the T5 transformer encoder for text""" def __init__( - self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True + self, version="google/t5-v1_1-xxl", device=None, max_length=77, freeze=True ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() + device = device or get_default_device_name() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device @@ -281,7 +293,7 @@ def forward(self, text): return_tensors="pt", ) tokens = batch_encoding["input_ids"].to(self.device) - with torch.autocast("cuda", enabled=False): + with safe_autocast(get_default_device_name(), enabled=False): outputs = self.transformer(input_ids=tokens) z = outputs.last_hidden_state return z @@ -296,9 +308,10 @@ class FrozenByT5Embedder(AbstractEmbModel): """ def __init__( - self, version="google/byt5-base", device="cuda", max_length=77, freeze=True + self, version="google/byt5-base", device=None, max_length=77, freeze=True ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() + device = device or get_default_device_name() self.tokenizer = ByT5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device @@ -323,7 +336,7 @@ def forward(self, text): return_tensors="pt", ) tokens = batch_encoding["input_ids"].to(self.device) - with torch.autocast("cuda", enabled=False): + with safe_autocast(get_default_device_name(), enabled=False): outputs = self.transformer(input_ids=tokens) z = outputs.last_hidden_state return z @@ -340,7 +353,7 @@ class FrozenCLIPEmbedder(AbstractEmbModel): def __init__( self, version="openai/clip-vit-large-patch14", - device="cuda", + device=None, max_length=77, freeze=True, layer="last", @@ -348,6 +361,7 @@ def __init__( always_return_pooled=False, ): # clip-vit-base-patch32 super().__init__() + device = device or get_default_device_name() assert layer in self.LAYERS self.tokenizer = CLIPTokenizer.from_pretrained(version) self.transformer = CLIPTextModel.from_pretrained(version) @@ -408,7 +422,7 @@ def __init__( self, arch="ViT-H-14", version="laion2b_s32b_b79k", - device="cuda", + device=None, max_length=77, freeze=True, layer="last", @@ -416,6 +430,7 @@ def __init__( legacy=True, ): super().__init__() + device = device or get_default_device_name() assert layer in self.LAYERS model, _, _ = open_clip.create_model_and_transforms( arch, @@ -510,12 +525,13 @@ def __init__( self, arch="ViT-H-14", version="laion2b_s32b_b79k", - device="cuda", + device=None, max_length=77, freeze=True, layer="last", ): super().__init__() + device = device or get_default_device_name() assert layer in self.LAYERS model, _, _ = open_clip.create_model_and_transforms( arch, device=torch.device("cpu"), pretrained=version @@ -580,7 +596,7 @@ def __init__( self, arch="ViT-H-14", version="laion2b_s32b_b79k", - device="cuda", + device=None, max_length=77, freeze=True, antialias=True, @@ -592,6 +608,7 @@ def __init__( init_device=None, ): super().__init__() + device = device or get_default_device_name() model, _, _ = open_clip.create_model_and_transforms( arch, device=torch.device(default(init_device, "cpu")), @@ -737,11 +754,12 @@ def __init__( self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", - device="cuda", + device=None, clip_max_length=77, t5_max_length=77, ): super().__init__() + device = device or get_default_device_name() self.clip_encoder = FrozenCLIPEmbedder( clip_version, device, max_length=clip_max_length ) @@ -1006,7 +1024,7 @@ def forward( noise = torch.randn_like(vid) vid = vid + noise * append_dims(sigmas, vid.ndim) - with torch.autocast("cuda", enabled=not self.disable_encoder_autocast): + with safe_autocast(get_default_device_name(), enabled=not self.disable_encoder_autocast): n_samples = ( self.en_and_decode_n_samples_a_time if self.en_and_decode_n_samples_a_time is not None diff --git a/sgm/util.py b/sgm/util.py index 66d9b2a69..6c83c8c1f 100644 --- a/sgm/util.py +++ b/sgm/util.py @@ -1,6 +1,7 @@ import functools import importlib import os +from contextlib import nullcontext from functools import partial from inspect import isfunction @@ -11,6 +12,10 @@ from safetensors.torch import load_file as load_safetensors +def get_default_device_name() -> str: + return os.environ.get("SGM_DEFAULT_DEVICE", "cuda" if torch.cuda.is_available() else "cpu") + + def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" @@ -273,3 +278,10 @@ def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): current_attribute = getattr(current_attribute, attribute) return (current_attribute, current_key) if return_key else current_attribute + + +def safe_autocast(device, **kwargs): + """Autocast that doesn't crash on devices unsupported by autocast.""" + if device not in ("cpu", "cuda"): + return nullcontext() + return torch.autocast(device, **kwargs)