Commit adb6cb76 authored by Billy Cao's avatar Billy Cao

Patch UNet Forward to support resolutions that are not multiples of 64

Also modifed the UI to no longer step in 64
parent 828438b4
...@@ -16,6 +16,7 @@ import ldm.modules.attention ...@@ -16,6 +16,7 @@ import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
import ldm.models.diffusion.ddim import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms import ldm.models.diffusion.plms
import ldm.modules.diffusionmodules.openaimodel
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
...@@ -26,6 +27,7 @@ def apply_optimizations(): ...@@ -26,6 +27,7 @@ def apply_optimizations():
undo_optimizations() undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.") print("Applying xformers cross attention optimization.")
......
...@@ -5,6 +5,7 @@ import importlib ...@@ -5,6 +5,7 @@ import importlib
import torch import torch
from torch import einsum from torch import einsum
import torch.nn.functional as F
from ldm.util import default from ldm.util import default
from einops import rearrange from einops import rearrange
...@@ -12,6 +13,8 @@ from einops import rearrange ...@@ -12,6 +13,8 @@ from einops import rearrange
from modules import shared from modules import shared
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from ldm.modules.diffusionmodules.util import timestep_embedding
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try: try:
...@@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x): ...@@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x):
return x + out return x + out
except NotImplementedError: except NotImplementedError:
return cross_attention_attnblock_forward(self, x) return cross_attention_attnblock_forward(self, x)
def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
if h.shape[-2:] != hs[-1].shape[-2:]:
h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
...@@ -380,8 +380,8 @@ def create_seed_inputs(): ...@@ -380,8 +380,8 @@ def create_seed_inputs():
with gr.Row(visible=False) as seed_extra_row_2: with gr.Row(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2) seed_extras.append(seed_extra_row_2)
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0) seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=1, label="Resize seed from width", value=0)
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0) seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=1, label="Resize seed from height", value=0)
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
...@@ -715,8 +715,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -715,8 +715,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
with gr.Group(): with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
with gr.Row(): with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
...@@ -724,8 +724,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -724,8 +724,8 @@ def create_ui(wrap_gradio_gpu_call):
enable_hr = gr.Checkbox(label='Highres. fix', value=False) enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options: with gr.Row(visible=False) as hr_options:
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0) firstphase_width = gr.Slider(minimum=0, maximum=1024, step=1, label="Firstpass width", value=0)
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0) firstphase_height = gr.Slider(minimum=0, maximum=1024, step=1, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row(equal_height=True): with gr.Row(equal_height=True):
...@@ -901,8 +901,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -901,8 +901,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group(): with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width") width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height") height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512, elem_id="img2img_height")
with gr.Row(): with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
...@@ -1231,8 +1231,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1231,8 +1231,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Preprocess images"): with gr.Tab(label="Preprocess images"):
process_src = gr.Textbox(label='Source directory') process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory') process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) process_width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) process_height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row(): with gr.Row():
...@@ -1289,8 +1289,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1289,8 +1289,8 @@ def create_ui(wrap_gradio_gpu_call):
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0) steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment