Skip to content
50 changes: 35 additions & 15 deletions scripts/demo/sampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
import os

import streamlit as st
from einops import repeat
from pytorch_lightning import seed_everything

from scripts.demo.streamlit_helpers import *
from scripts.demo.streamlit_helpers import (
get_interactive_image,
get_unique_embedder_keys_from_conditioner,
init_embedder_options,
init_sampling,
init_save_locally,
init_st,
lowvram_model_mover,
samples_to_streamlit,
set_lowvram_mode,
)
from sgm.inference.helpers import (
do_img2img,
do_sample,
get_input_image_tensor,
perform_save_locally,
)

SAVE_PATH = "outputs/demo/txt2img/"

Expand Down Expand Up @@ -97,16 +117,7 @@ def load_img(display=True, key=None, device="cuda"):
return None
if display:
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
return image.to(device)
return get_input_image_tensor(image, device=device)


def run_txt2img(
Expand Down Expand Up @@ -143,7 +154,9 @@ def run_txt2img(

if st.button("Sample"):
st.write(f"**Model I:** {version}")
out = do_sample(
st.text("Sampling")
outputs = st.empty()
samples, latents = do_sample(
state["model"],
sampler,
value_dict,
Expand All @@ -153,9 +166,11 @@ def run_txt2img(
C,
F,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
return_latents=return_latents,
return_latents=True,
filter=filter,
move_model=lowvram_model_mover,
)
samples_to_streamlit(outputs, samples)
return out


Expand Down Expand Up @@ -194,16 +209,20 @@ def run_img2img(
num_samples = num_rows * num_cols

if st.button("Sample"):
out = do_img2img(
st.text("Sampling")
outputs = st.empty()
samples, latents = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples),
state["model"],
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
return_latents=return_latents,
return_latents=True,
filter=filter,
move_model=lowvram_model_mover,
)
samples_to_streamlit(outputs, samples)
return out


Expand Down Expand Up @@ -244,6 +263,7 @@ def apply_refiner(
skip_encode=True,
filter=filter,
add_noise=not finish_denoising,
move_model=lowvram_model_mover,
)

return samples
Expand Down
Loading