Commit adb6cb76 authored by Billy Cao's avatar Billy Cao

Patch UNet Forward to support resolutions that are not multiples of 64

Also modifed the UI to no longer step in 64
parent 828438b4
......@@ -16,6 +16,7 @@ import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.diffusionmodules.openaimodel
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
......@@ -26,6 +27,7 @@ def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
......
......@@ -5,6 +5,7 @@ import importlib
import torch
from torch import einsum
import torch.nn.functional as F
from ldm.util import default
from einops import rearrange
......@@ -12,6 +13,8 @@ from einops import rearrange
from modules import shared
from modules.hypernetworks import hypernetwork
from ldm.modules.diffusionmodules.util import timestep_embedding
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
......@@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x):
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
if h.shape[-2:] != hs[-1].shape[-2:]:
h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
......@@ -380,8 +380,8 @@ def create_seed_inputs():
with gr.Row(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2)
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0)
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0)
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=1, label="Resize seed from width", value=0)
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=1, label="Resize seed from height", value=0)
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
......@@ -715,8 +715,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
......@@ -724,8 +724,8 @@ def create_ui(wrap_gradio_gpu_call):
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options:
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=1, label="Firstpass width", value=0)
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=1, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row(equal_height=True):
......@@ -901,8 +901,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
......@@ -1231,8 +1231,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Preprocess images"):
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
process_width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
......@@ -1289,8 +1289,8 @@ def create_ui(wrap_gradio_gpu_call):
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
training_width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment