From 743da575477845a4234798b9b99e0f0e5579ce84 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Jul 2023 10:33:31 +0300 Subject: [PATCH 01/10] Don't use star imports in sampling demo --- scripts/demo/sampling.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 2984dbf7a..2dec4757b 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,6 +1,23 @@ +import os + +import numpy as np +import streamlit as st +import torch +from einops import repeat from pytorch_lightning import seed_everything -from scripts.demo.streamlit_helpers import * +from scripts.demo.streamlit_helpers import ( + do_img2img, + do_sample, + get_interactive_image, + get_unique_embedder_keys_from_conditioner, + init_embedder_options, + init_sampling, + init_save_locally, + init_st, + perform_save_locally, + set_lowvram_mode, +) SAVE_PATH = "outputs/demo/txt2img/" From 6e300a19568be9cfb06666ae21cf3a63a54d3a2a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Jul 2023 10:42:57 +0300 Subject: [PATCH 02/10] Use get_input_image_tensor helper --- scripts/demo/sampling.py | 14 ++------------ scripts/demo/streamlit_helpers.py | 1 - 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 2dec4757b..85e34ae0b 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,8 +1,6 @@ import os -import numpy as np import streamlit as st -import torch from einops import repeat from pytorch_lightning import seed_everything @@ -18,6 +16,7 @@ perform_save_locally, set_lowvram_mode, ) +from sgm.inference.helpers import get_input_image_tensor SAVE_PATH = "outputs/demo/txt2img/" @@ -114,16 +113,7 @@ def load_img(display=True, key=None, device="cuda"): return None if display: st.image(image) - w, h = image.size - print(f"loaded input image of size ({w}, {h})") - width, height = map( - lambda x: x - x % 64, (w, h) - ) # resize to integer multiple of 64 - image = image.resize((width, height)) - image = np.array(image.convert("RGB")) - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - return image.to(device) + return get_input_image_tensor(image, device=device) def run_txt2img( diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 82b7fb9cc..03209ae28 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -12,7 +12,6 @@ from safetensors.torch import load_file as load_safetensors from torch import autocast from torchvision import transforms -from torchvision.utils import make_grid from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from sgm.modules.diffusionmodules.sampling import ( From 71e348635a6ea502227d6705b3c9396cff735214 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Jul 2023 10:48:19 +0300 Subject: [PATCH 03/10] Move Txt2NoisyDiscretizationWrapper to sgm.inference.helpers --- scripts/demo/streamlit_helpers.py | 31 +------------------------------ sgm/inference/helpers.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 03209ae28..97584c3aa 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -14,6 +14,7 @@ from torchvision import transforms from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering +from sgm.inference.helpers import Txt2NoisyDiscretizationWrapper from sgm.modules.diffusionmodules.sampling import ( DPMPP2MSampler, DPMPP2SAncestralSampler, @@ -251,36 +252,6 @@ def __call__(self, *args, **kwargs): return sigmas -class Txt2NoisyDiscretizationWrapper: - """ - wraps a discretizer, and prunes the sigmas - params: - strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned) - """ - - def __init__(self, discretization, strength: float = 0.0, original_steps=None): - self.discretization = discretization - self.strength = strength - self.original_steps = original_steps - assert 0.0 <= self.strength <= 1.0 - - def __call__(self, *args, **kwargs): - # sigmas start large first, and decrease then - sigmas = self.discretization(*args, **kwargs) - print(f"sigmas after discretization, before pruning img2img: ", sigmas) - sigmas = torch.flip(sigmas, (0,)) - if self.original_steps is None: - steps = len(sigmas) - else: - steps = self.original_steps + 1 - prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0) - sigmas = sigmas[prune_index:] - print("prune index:", prune_index) - sigmas = torch.flip(sigmas, (0,)) - print(f"sigmas after pruning: ", sigmas) - return sigmas - - def get_guider(key): guider = st.sidebar.selectbox( f"Discretization #{key}", diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 1c653708b..1a94fffca 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -98,6 +98,36 @@ def __call__(self, *args, **kwargs): return sigmas +class Txt2NoisyDiscretizationWrapper: + """ + wraps a discretizer, and prunes the sigmas + params: + strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned) + """ + + def __init__(self, discretization, strength: float = 0.0, original_steps=None): + self.discretization = discretization + self.strength = strength + self.original_steps = original_steps + assert 0.0 <= self.strength <= 1.0 + + def __call__(self, *args, **kwargs): + # sigmas start large first, and decrease then + sigmas = self.discretization(*args, **kwargs) + print(f"sigmas after discretization, before pruning img2img: ", sigmas) + sigmas = torch.flip(sigmas, (0,)) + if self.original_steps is None: + steps = len(sigmas) + else: + steps = self.original_steps + 1 + prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0) + sigmas = sigmas[prune_index:] + print("prune index:", prune_index) + sigmas = torch.flip(sigmas, (0,)) + print(f"sigmas after pruning: ", sigmas) + return sigmas + + def do_sample( model, sampler, From 3b116d729d0c38d9ccc42dcf524412a7590ca8e4 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Jul 2023 10:49:34 +0300 Subject: [PATCH 04/10] Use embed_watermark from helpers --- scripts/demo/streamlit_helpers.py | 52 ++----------------------------- 1 file changed, 3 insertions(+), 49 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 97584c3aa..0309bcd58 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -6,7 +6,6 @@ import streamlit as st import torch from einops import rearrange, repeat -from imwatermark import WatermarkEncoder from omegaconf import ListConfig, OmegaConf from PIL import Image from safetensors.torch import load_file as load_safetensors @@ -14,7 +13,7 @@ from torchvision import transforms from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering -from sgm.inference.helpers import Txt2NoisyDiscretizationWrapper +from sgm.inference.helpers import Txt2NoisyDiscretizationWrapper, embed_watermark from sgm.modules.diffusionmodules.sampling import ( DPMPP2MSampler, DPMPP2SAncestralSampler, @@ -26,51 +25,6 @@ from sgm.util import append_dims, instantiate_from_config -class WatermarkEmbedder: - def __init__(self, watermark): - self.watermark = watermark - self.num_bits = len(WATERMARK_BITS) - self.encoder = WatermarkEncoder() - self.encoder.set_watermark("bits", self.watermark) - - def __call__(self, image: torch.Tensor): - """ - Adds a predefined watermark to the input image - - Args: - image: ([N,] B, C, H, W) in range [0, 1] - - Returns: - same as input but watermarked - """ - # watermarking libary expects input as cv2 BGR format - squeeze = len(image.shape) == 4 - if squeeze: - image = image[None, ...] - n = image.shape[0] - image_np = rearrange( - (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" - ).numpy()[:, :, :, ::-1] - # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] - for k in range(image_np.shape[0]): - image_np[k] = self.encoder.encode(image_np[k], "dwtDct") - image = torch.from_numpy( - rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) - ).to(image.device) - image = torch.clamp(image / 255, min=0.0, max=1.0) - if squeeze: - image = image[0] - return image - - -# A fixed 48-bit message that was choosen at random -# WATERMARK_MESSAGE = 0xB3EC907BB19E -WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 -# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 -WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] -embed_watemark = WatermarkEmbedder(WATERMARK_BITS) - - @st.cache_resource() def init_st(version_dict, load_ckpt=True, load_filter=True): state = dict() @@ -209,7 +163,7 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): def perform_save_locally(save_path, samples): os.makedirs(os.path.join(save_path), exist_ok=True) base_count = len(os.listdir(os.path.join(save_path))) - samples = embed_watemark(samples) + samples = embed_watermark(samples) for sample in samples: sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(sample.astype(np.uint8)).save( @@ -706,7 +660,7 @@ def denoiser(x, sigma, c): if filter is not None: samples = filter(samples) - grid = embed_watemark(torch.stack([samples])) + grid = embed_watermark(torch.stack([samples])) grid = rearrange(grid, "n b c h w -> (n h) (b w) c") outputs.image(grid.cpu().numpy()) if return_latents: From 973875f649379b5035cb3bf8e30245ef0550146a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Jul 2023 10:51:03 +0300 Subject: [PATCH 05/10] Use perform_save_locally from helpers --- scripts/demo/sampling.py | 3 +-- scripts/demo/streamlit_helpers.py | 12 ------------ 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 85e34ae0b..9dd7c7562 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -13,10 +13,9 @@ init_sampling, init_save_locally, init_st, - perform_save_locally, set_lowvram_mode, ) -from sgm.inference.helpers import get_input_image_tensor +from sgm.inference.helpers import get_input_image_tensor, perform_save_locally SAVE_PATH = "outputs/demo/txt2img/" diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 0309bcd58..869f049e1 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -160,18 +160,6 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): return value_dict -def perform_save_locally(save_path, samples): - os.makedirs(os.path.join(save_path), exist_ok=True) - base_count = len(os.listdir(os.path.join(save_path))) - samples = embed_watermark(samples) - for sample in samples: - sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") - Image.fromarray(sample.astype(np.uint8)).save( - os.path.join(save_path, f"{base_count:09}.png") - ) - base_count += 1 - - def init_save_locally(_dir, init_value: bool = False): save_locally = st.sidebar.checkbox("Save images locally", value=init_value) if save_locally: From 3319970103bfa279e019e04266f6d5c5baaa051e Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Jul 2023 10:51:39 +0300 Subject: [PATCH 06/10] Use Img2ImgDiscretizationWrapper from helpers --- scripts/demo/streamlit_helpers.py | 30 +++++------------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 869f049e1..ce3681a32 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -13,7 +13,11 @@ from torchvision import transforms from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering -from sgm.inference.helpers import Txt2NoisyDiscretizationWrapper, embed_watermark +from sgm.inference.helpers import ( + Img2ImgDiscretizationWrapper, + Txt2NoisyDiscretizationWrapper, + embed_watermark, +) from sgm.modules.diffusionmodules.sampling import ( DPMPP2MSampler, DPMPP2SAncestralSampler, @@ -170,30 +174,6 @@ def init_save_locally(_dir, init_value: bool = False): return save_locally, save_path -class Img2ImgDiscretizationWrapper: - """ - wraps a discretizer, and prunes the sigmas - params: - strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) - """ - - def __init__(self, discretization, strength: float = 1.0): - self.discretization = discretization - self.strength = strength - assert 0.0 <= self.strength <= 1.0 - - def __call__(self, *args, **kwargs): - # sigmas start large first, and decrease then - sigmas = self.discretization(*args, **kwargs) - print(f"sigmas after discretization, before pruning img2img: ", sigmas) - sigmas = torch.flip(sigmas, (0,)) - sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] - print("prune index:", max(int(self.strength * len(sigmas)), 1)) - sigmas = torch.flip(sigmas, (0,)) - print(f"sigmas after pruning: ", sigmas) - return sigmas - - def get_guider(key): guider = st.sidebar.selectbox( f"Discretization #{key}", From 56c2f02420a8ba659af968ea3d30efa2042467d1 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Jul 2023 10:52:50 +0300 Subject: [PATCH 07/10] Use load_model_from_config from sgm.util --- scripts/demo/streamlit_helpers.py | 47 ++++++------------------------- sgm/inference/api.py | 1 + sgm/util.py | 1 - 3 files changed, 9 insertions(+), 40 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index ce3681a32..036a342e1 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -8,7 +8,6 @@ from einops import rearrange, repeat from omegaconf import ListConfig, OmegaConf from PIL import Image -from safetensors.torch import load_file as load_safetensors from torch import autocast from torchvision import transforms @@ -26,7 +25,7 @@ HeunEDMSampler, LinearMultistepSampler, ) -from sgm.util import append_dims, instantiate_from_config +from sgm.util import append_dims, load_model_from_config @st.cache_resource() @@ -37,9 +36,14 @@ def init_st(version_dict, load_ckpt=True, load_filter=True): ckpt = version_dict["ckpt"] config = OmegaConf.load(config) - model, msg = load_model_from_config(config, ckpt if load_ckpt else None) + model = load_model_from_config( + config=config, + ckpt=(ckpt if load_ckpt else None), + ) + model = initial_model_load(model) + model.eval() - state["msg"] = msg + state["msg"] = None state["model"] = model state["ckpt"] = ckpt if load_ckpt else None state["config"] = config @@ -76,41 +80,6 @@ def unload_model(model): torch.cuda.empty_cache() -def load_model_from_config(config, ckpt=None, verbose=True): - model = instantiate_from_config(config.model) - - if ckpt is not None: - print(f"Loading model from {ckpt}") - if ckpt.endswith("ckpt"): - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - global_step = pl_sd["global_step"] - st.info(f"loaded ckpt from global step {global_step}") - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - elif ckpt.endswith("safetensors"): - sd = load_safetensors(ckpt) - else: - raise NotImplementedError - - msg = None - - m, u = model.load_state_dict(sd, strict=False) - - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - else: - msg = None - - model = initial_model_load(model) - model.eval() - return model, msg - - def get_unique_embedder_keys_from_conditioner(conditioner): return list(set([x.input_key for x in conditioner.embedders])) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 12efc064c..1c10a6f22 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -176,6 +176,7 @@ def __init__( def _load_model(self, device="cuda", use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) + model.eval() if model is None: raise ValueError(f"Model {self.model_id} could not be loaded") model.to(device) diff --git a/sgm/util.py b/sgm/util.py index c5e68f4b5..20a90ab91 100644 --- a/sgm/util.py +++ b/sgm/util.py @@ -226,7 +226,6 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True): for param in model.parameters(): param.requires_grad = False - model.eval() return model From 6d083f8d9ae7bf5ea4ad9bc795c8927bfcdda5a9 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Jul 2023 10:55:44 +0300 Subject: [PATCH 08/10] Use get_unique_embedder_keys_from_conditioner from helpers --- scripts/demo/streamlit_helpers.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 036a342e1..7df927044 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -16,6 +16,7 @@ Img2ImgDiscretizationWrapper, Txt2NoisyDiscretizationWrapper, embed_watermark, + get_unique_embedder_keys_from_conditioner, ) from sgm.modules.diffusionmodules.sampling import ( DPMPP2MSampler, @@ -80,10 +81,6 @@ def unload_model(model): torch.cuda.empty_cache() -def get_unique_embedder_keys_from_conditioner(conditioner): - return list(set([x.input_key for x in conditioner.embedders])) - - def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): # Hardcoded demo settings; might undergo some changes in the future From f96a27c2cb1495e14b64b9515bf24c07babc3b5a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Jul 2023 10:56:11 +0300 Subject: [PATCH 09/10] Use get_batch from helpers --- scripts/demo/streamlit_helpers.py | 63 ++----------------------------- 1 file changed, 3 insertions(+), 60 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 7df927044..999ba9a9a 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -1,12 +1,11 @@ import math import os -from typing import List, Union +from typing import List -import numpy as np import streamlit as st import torch from einops import rearrange, repeat -from omegaconf import ListConfig, OmegaConf +from omegaconf import OmegaConf from PIL import Image from torch import autocast from torchvision import transforms @@ -16,6 +15,7 @@ Img2ImgDiscretizationWrapper, Txt2NoisyDiscretizationWrapper, embed_watermark, + get_batch, get_unique_embedder_keys_from_conditioner, ) from sgm.modules.diffusionmodules.sampling import ( @@ -455,63 +455,6 @@ def denoiser(input, sigma, c): return samples -def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): - # Hardcoded demo setups; might undergo some changes in the future - - batch = {} - batch_uc = {} - - for key in keys: - if key == "txt": - batch["txt"] = ( - np.repeat([value_dict["prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() - ) - batch_uc["txt"] = ( - np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() - ) - elif key == "original_size_as_tuple": - batch["original_size_as_tuple"] = ( - torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) - .to(device) - .repeat(*N, 1) - ) - elif key == "crop_coords_top_left": - batch["crop_coords_top_left"] = ( - torch.tensor( - [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] - ) - .to(device) - .repeat(*N, 1) - ) - elif key == "aesthetic_score": - batch["aesthetic_score"] = ( - torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) - ) - batch_uc["aesthetic_score"] = ( - torch.tensor([value_dict["negative_aesthetic_score"]]) - .to(device) - .repeat(*N, 1) - ) - - elif key == "target_size_as_tuple": - batch["target_size_as_tuple"] = ( - torch.tensor([value_dict["target_height"], value_dict["target_width"]]) - .to(device) - .repeat(*N, 1) - ) - else: - batch[key] = value_dict[key] - - for key in batch.keys(): - if key not in batch_uc and isinstance(batch[key], torch.Tensor): - batch_uc[key] = torch.clone(batch[key]) - return batch, batch_uc - - @torch.no_grad() def do_img2img( img, From 1fd8de2165cb34a7bdab476edbe78348f989b775 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Jul 2023 11:20:13 +0300 Subject: [PATCH 10/10] Adapt lowvram mode to helper do_sample and do_txt2img by way of a model mover context, and deduplicate --- scripts/demo/sampling.py | 28 +++- scripts/demo/streamlit_helpers.py | 211 +++--------------------------- sgm/inference/helpers.py | 110 ++++++++++------ 3 files changed, 110 insertions(+), 239 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 9dd7c7562..b75b18288 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -5,17 +5,22 @@ from pytorch_lightning import seed_everything from scripts.demo.streamlit_helpers import ( - do_img2img, - do_sample, get_interactive_image, get_unique_embedder_keys_from_conditioner, init_embedder_options, init_sampling, init_save_locally, init_st, + lowvram_model_mover, + samples_to_streamlit, set_lowvram_mode, ) -from sgm.inference.helpers import get_input_image_tensor, perform_save_locally +from sgm.inference.helpers import ( + do_img2img, + do_sample, + get_input_image_tensor, + perform_save_locally, +) SAVE_PATH = "outputs/demo/txt2img/" @@ -149,7 +154,9 @@ def run_txt2img( if st.button("Sample"): st.write(f"**Model I:** {version}") - out = do_sample( + st.text("Sampling") + outputs = st.empty() + samples, latents = do_sample( state["model"], sampler, value_dict, @@ -159,9 +166,11 @@ def run_txt2img( C, F, force_uc_zero_embeddings=["txt"] if not is_legacy else [], - return_latents=return_latents, + return_latents=True, filter=filter, + move_model=lowvram_model_mover, ) + samples_to_streamlit(outputs, samples) return out @@ -200,16 +209,20 @@ def run_img2img( num_samples = num_rows * num_cols if st.button("Sample"): - out = do_img2img( + st.text("Sampling") + outputs = st.empty() + samples, latents = do_img2img( repeat(img, "1 ... -> n ...", n=num_samples), state["model"], sampler, value_dict, num_samples, force_uc_zero_embeddings=["txt"] if not is_legacy else [], - return_latents=return_latents, + return_latents=True, filter=filter, + move_model=lowvram_model_mover, ) + samples_to_streamlit(outputs, samples) return out @@ -250,6 +263,7 @@ def apply_refiner( skip_encode=True, filter=filter, add_noise=not finish_denoising, + move_model=lowvram_model_mover, ) return samples diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 999ba9a9a..27baafcdf 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -1,22 +1,17 @@ -import math import os -from typing import List +from contextlib import contextmanager import streamlit as st import torch from einops import rearrange, repeat from omegaconf import OmegaConf from PIL import Image -from torch import autocast from torchvision import transforms from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from sgm.inference.helpers import ( Img2ImgDiscretizationWrapper, Txt2NoisyDiscretizationWrapper, - embed_watermark, - get_batch, - get_unique_embedder_keys_from_conditioner, ) from sgm.modules.diffusionmodules.sampling import ( DPMPP2MSampler, @@ -26,7 +21,7 @@ HeunEDMSampler, LinearMultistepSampler, ) -from sgm.util import append_dims, load_model_from_config +from sgm.util import load_model_from_config @st.cache_resource() @@ -53,10 +48,6 @@ def init_st(version_dict, load_ckpt=True, load_filter=True): return state -def load_model(model): - model.cuda() - - lowvram_mode = False @@ -74,11 +65,19 @@ def initial_model_load(model): return model -def unload_model(model): - global lowvram_mode - if lowvram_mode: - model.cpu() - torch.cuda.empty_cache() +@contextmanager +def lowvram_model_mover(model, device): + """ + Context manager that moves the model to the device and back to CPU + afterwards if lowvram_mode is set to True. + """ + try: + model.to(device) + yield + finally: + if lowvram_mode: + model.cpu() + torch.cuda.empty_cache() def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): @@ -367,179 +366,7 @@ def get_init_img(batch_size=1, key=None): return init_image -def do_sample( - model, - sampler, - value_dict, - num_samples, - H, - W, - C, - F, - force_uc_zero_embeddings: List = None, - batch2model_input: List = None, - return_latents=False, - filter=None, -): - if force_uc_zero_embeddings is None: - force_uc_zero_embeddings = [] - if batch2model_input is None: - batch2model_input = [] - - st.text("Sampling") - - outputs = st.empty() - precision_scope = autocast - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - num_samples = [num_samples] - load_model(model.conditioner) - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - num_samples, - ) - for key in batch: - if isinstance(batch[key], torch.Tensor): - print(key, batch[key].shape) - elif isinstance(batch[key], list): - print(key, [len(l) for l in batch[key]]) - else: - print(key, batch[key]) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - unload_model(model.conditioner) - - for k in c: - if not k == "crossattn": - c[k], uc[k] = map( - lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) - ) - - additional_model_inputs = {} - for k in batch2model_input: - additional_model_inputs[k] = batch[k] - - shape = (math.prod(num_samples), C, H // F, W // F) - randn = torch.randn(shape).to("cuda") - - def denoiser(input, sigma, c): - return model.denoiser( - model.model, input, sigma, c, **additional_model_inputs - ) - - load_model(model.denoiser) - load_model(model.model) - samples_z = sampler(denoiser, randn, cond=c, uc=uc) - unload_model(model.model) - unload_model(model.denoiser) - - load_model(model.first_stage_model) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - unload_model(model.first_stage_model) - - if filter is not None: - samples = filter(samples) - - grid = torch.stack([samples]) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) - - if return_latents: - return samples, samples_z - return samples - - -@torch.no_grad() -def do_img2img( - img, - model, - sampler, - value_dict, - num_samples, - force_uc_zero_embeddings=[], - additional_kwargs={}, - offset_noise_level: int = 0.0, - return_latents=False, - skip_encode=False, - filter=None, - add_noise=True, -): - st.text("Sampling") - - outputs = st.empty() - precision_scope = autocast - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - load_model(model.conditioner) - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [num_samples], - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - unload_model(model.conditioner) - for k in c: - c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc)) - - for k in additional_kwargs: - c[k] = uc[k] = additional_kwargs[k] - if skip_encode: - z = img - else: - load_model(model.first_stage_model) - z = model.encode_first_stage(img) - unload_model(model.first_stage_model) - - noise = torch.randn_like(z) - - sigmas = sampler.discretization(sampler.num_steps).cuda() - sigma = sigmas[0] - - st.info(f"all sigmas: {sigmas}") - st.info(f"noising sigma: {sigma}") - if offset_noise_level > 0.0: - noise = noise + offset_noise_level * append_dims( - torch.randn(z.shape[0], device=z.device), z.ndim - ) - if add_noise: - noised_z = z + noise * append_dims(sigma, z.ndim).cuda() - noised_z = noised_z / torch.sqrt( - 1.0 + sigmas[0] ** 2.0 - ) # Note: hardcoded to DDPM-like scaling. need to generalize later. - else: - noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0) - - def denoiser(x, sigma, c): - return model.denoiser(model.model, x, sigma, c) - - load_model(model.denoiser) - load_model(model.model) - samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - unload_model(model.model) - unload_model(model.denoiser) - - load_model(model.first_stage_model) - samples_x = model.decode_first_stage(samples_z) - unload_model(model.first_stage_model) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - if filter is not None: - samples = filter(samples) - - grid = embed_watermark(torch.stack([samples])) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) - if return_latents: - return samples, samples_z - return samples +def samples_to_streamlit(outputs, samples): + grid = torch.stack([samples]) + grid = rearrange(grid, "n b c h w -> (n h) (b w) c") + outputs.image(grid.cpu().numpy()) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 1a94fffca..1fcac5819 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -1,3 +1,4 @@ +import contextlib import os from typing import Union, List, Optional @@ -128,6 +129,16 @@ def __call__(self, *args, **kwargs): return sigmas +@contextlib.contextmanager +def default_model_mover(model, device): + """ + Default model mover: ensure the model is loaded to `device` on entry, + do not unload on exit + """ + model.to(device) + yield + + def do_sample( model, sampler, @@ -142,6 +153,7 @@ def do_sample( return_latents=False, filter=None, device="cuda", + move_model=default_model_mover, ): if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] @@ -149,26 +161,27 @@ def do_sample( batch2model_input = [] with torch.no_grad(): - with autocast(device) as precision_scope: + with autocast(device): with model.ema_scope(): num_samples = [num_samples] - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - num_samples, - ) - for key in batch: - if isinstance(batch[key], torch.Tensor): - print(key, batch[key].shape) - elif isinstance(batch[key], list): - print(key, [len(l) for l in batch[key]]) - else: - print(key, batch[key]) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) + with move_model(model.conditioner, device): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) for k in c: if not k == "crossattn": @@ -188,9 +201,13 @@ def denoiser(input, sigma, c): model.model, input, sigma, c, **additional_model_inputs ) - samples_z = sampler(denoiser, randn, cond=c, uc=uc) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + with move_model(model.denoiser, device): + with move_model(model.model, device): + samples_z = sampler(denoiser, randn, cond=c, uc=uc) + + with move_model(model.first_stage_model, device): + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) if filter is not None: samples = filter(samples) @@ -283,21 +300,23 @@ def do_img2img( skip_encode=False, filter=None, device="cuda", + add_noise=True, + move_model=default_model_mover, ): with torch.no_grad(): - with autocast(device) as precision_scope: + with autocast(device): with model.ema_scope(): - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [num_samples], - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - + with move_model(model.conditioner, device): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [num_samples], + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) for k in c: c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) @@ -306,8 +325,11 @@ def do_img2img( if skip_encode: z = img else: - z = model.encode_first_stage(img) + with move_model(model.first_stage_model, device): + z = model.encode_first_stage(img) + noise = torch.randn_like(z) + sigmas = sampler.discretization(sampler.num_steps) sigma = sigmas[0].to(z.device) @@ -315,16 +337,24 @@ def do_img2img( noise = noise + offset_noise_level * append_dims( torch.randn(z.shape[0], device=z.device), z.ndim ) - noised_z = z + noise * append_dims(sigma, z.ndim) - noised_z = noised_z / torch.sqrt( - 1.0 + sigmas[0] ** 2.0 - ) # Note: hardcoded to DDPM-like scaling. need to generalize later. + + if add_noise: + noised_z = z + noise * append_dims(sigma, z.ndim) + # Note: hardcoded to DDPM-like scaling. need to generalize later. + noised_z = noised_z / torch.sqrt(1.0 + sigmas[0] ** 2.0) + else: + noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0) def denoiser(x, sigma, c): return model.denoiser(model.model, x, sigma, c) - samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - samples_x = model.decode_first_stage(samples_z) + with move_model(model.denoiser, device): + with move_model(model.model, device): + samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) + + with move_model(model.first_stage_model, device): + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) if filter is not None: