Commit ef27a18b authored by AUTOMATIC's avatar AUTOMATIC

Hires fix rework

parent fd4461d4
import base64 import base64
import io import io
import math
import os import os
import re import re
from pathlib import Path from pathlib import Path
...@@ -164,6 +165,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None): ...@@ -164,6 +165,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
return None return None
def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale"""
firstpass_width = res.get('First pass size-1', None)
firstpass_height = res.get('First pass size-2', None)
if firstpass_width is None or firstpass_height is None:
return
firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
width = int(res.get("Size-1", 512))
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
# old algorithm for auto-calculating first pass size
desired_pixel_count = 512 * 512
actual_pixel_count = width * height
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
firstpass_width = math.ceil(scale * width / 64) * 64
firstpass_height = math.ceil(scale * height / 64) * 64
hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
res['Size-1'] = firstpass_width
res['Size-2'] = firstpass_height
res['Hires upscale'] = hr_scale
def parse_generation_parameters(x: str): def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI: """parses generation parameters string, the one you see in text field under the picture in UI:
``` ```
...@@ -221,6 +251,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model ...@@ -221,6 +251,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
hypernet_hash = res.get("Hypernet hash", None) hypernet_hash = res.get("Hypernet hash", None)
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
restore_old_hires_fix_params(res)
return res return res
......
...@@ -230,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts): ...@@ -230,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
return draw_grid_annotations(im, width, height, hor_texts, ver_texts) return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
def resize_image(resize_mode, im, width, height): def resize_image(resize_mode, im, width, height, upscaler_name=None):
"""
Resizes an image with the specified resize_mode, width, and height.
Args:
resize_mode: The mode to use when resizing the image.
0: Resize the image to the specified width and height.
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
im: The image to resize.
width: The width to resize the image to.
height: The height to resize the image to.
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
"""
upscaler_name = upscaler_name or opts.upscaler_for_img2img
def resize(im, w, h): def resize(im, w, h):
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L': if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS) return im.resize((w, h), resample=LANCZOS)
scale = max(w / im.width, h / im.height) scale = max(w / im.width, h / im.height)
if scale > 1.0: if scale > 1.0:
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img] upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}" assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
upscaler = upscalers[0] upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path) im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
......
...@@ -658,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -658,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None sampler = None
def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.enable_hr = enable_hr self.enable_hr = enable_hr
self.denoising_strength = denoising_strength self.denoising_strength = denoising_strength
self.firstphase_width = firstphase_width self.hr_scale = hr_scale
self.firstphase_height = firstphase_height self.hr_upscaler = hr_upscaler
self.truncate_x = 0
self.truncate_y = 0 if firstphase_width != 0 or firstphase_height != 0:
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
self.hr_scale = self.width / firstphase_width
self.width = firstphase_width
self.height = firstphase_height
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr: if self.enable_hr:
...@@ -674,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -674,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else: else:
state.job_count = state.job_count * 2 state.job_count = state.job_count * 2
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" self.extra_generation_params["Hires upscale"] = self.hr_scale
if self.hr_upscaler is not None:
if self.firstphase_width == 0 or self.firstphase_height == 0: self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
desired_pixel_count = 512 * 512
actual_pixel_count = self.width * self.height
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
self.firstphase_width = math.ceil(scale * self.width / 64) * 64
self.firstphase_height = math.ceil(scale * self.height / 64) * 64
firstphase_width_truncated = int(scale * self.width)
firstphase_height_truncated = int(scale * self.height)
else:
width_ratio = self.width / self.firstphase_width
height_ratio = self.height / self.firstphase_height
if width_ratio > height_ratio:
firstphase_width_truncated = self.firstphase_width
firstphase_height_truncated = self.firstphase_width * self.height / self.width
else:
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
if not self.enable_hr: latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode
if self.enable_hr and latent_scale_mode is None:
assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) if not self.enable_hr:
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height)) return samples
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] target_width = int(self.width * self.hr_scale)
target_height = int(self.height * self.hr_scale)
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
def save_intermediate(image, index): def save_intermediate(image, index):
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return return
...@@ -723,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -723,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix: if latent_scale_mode is not None:
for i in range(samples.shape[0]): for i in range(samples.shape[0]):
save_intermediate(samples, i) save_intermediate(samples, i)
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
# Avoid making the inpainting conditioning unless necessary as # Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again. # this does need some extra compute to decode / encode the image again.
...@@ -747,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -747,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
save_intermediate(image, i) save_intermediate(image, i)
image = images.resize_image(0, image, self.width, self.height) image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0) image = np.moveaxis(image, 2, 0)
batch_images.append(image) batch_images.append(image)
...@@ -764,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -764,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
x = None x = None
......
...@@ -327,7 +327,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { ...@@ -327,7 +327,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration"), {
...@@ -545,6 +544,12 @@ opts = Options() ...@@ -545,6 +544,12 @@ opts = Options()
if os.path.exists(config_filename): if os.path.exists(config_filename):
opts.load(config_filename) opts.load(config_filename)
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
"Latent": "bilinear",
"Latent (nearest)": "nearest",
}
sd_upscalers = [] sd_upscalers = []
sd_model = None sd_model = None
......
...@@ -8,7 +8,7 @@ import modules.processing as processing ...@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
...@@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
tiling=tiling, tiling=tiling,
enable_hr=enable_hr, enable_hr=enable_hr,
denoising_strength=denoising_strength if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None,
firstphase_width=firstphase_width if enable_hr else None, hr_scale=hr_scale,
firstphase_height=firstphase_height if enable_hr else None, hr_upscaler=hr_upscaler,
) )
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_txt2img
......
...@@ -684,11 +684,11 @@ def create_ui(): ...@@ -684,11 +684,11 @@ def create_ui():
with gr.Row(): with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr") enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
with gr.Row(visible=False) as hr_options: with gr.Row(visible=False) as hr_options:
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width") hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height") hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
with gr.Row(equal_height=True): with gr.Row(equal_height=True):
...@@ -729,8 +729,8 @@ def create_ui(): ...@@ -729,8 +729,8 @@ def create_ui():
width, width,
enable_hr, enable_hr,
denoising_strength, denoising_strength,
firstphase_width, hr_scale,
firstphase_height, hr_upscaler,
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
...@@ -762,7 +762,6 @@ def create_ui(): ...@@ -762,7 +762,6 @@ def create_ui():
outputs=[hr_options], outputs=[hr_options],
) )
txt2img_paste_fields = [ txt2img_paste_fields = [
(txt2img_prompt, "Prompt"), (txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"), (txt2img_negative_prompt, "Negative prompt"),
...@@ -781,8 +780,8 @@ def create_ui(): ...@@ -781,8 +780,8 @@ def create_ui():
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d), (enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"), (hr_scale, "Hires upscale"),
(firstphase_height, "First pass size-2"), (hr_upscaler, "Hires upscaler"),
*modules.scripts.scripts_txt2img.infotext_fields *modules.scripts.scripts_txt2img.infotext_fields
] ]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
......
...@@ -202,7 +202,7 @@ axis_options = [ ...@@ -202,7 +202,7 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None), AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
AxisOption("VAE", str, apply_vae, format_value_add_label, None), AxisOption("VAE", str, apply_vae, format_value_add_label, None),
AxisOption("Styles", str, apply_styles, format_value_add_label, None), AxisOption("Styles", str, apply_styles, format_value_add_label, None),
...@@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object): ...@@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork self.hypernetwork = opts.sd_hypernetwork
self.model = shared.sd_model self.model = shared.sd_model
self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix
self.vae = opts.sd_vae self.vae = opts.sd_vae
def __exit__(self, exc_type, exc_value, tb): def __exit__(self, exc_type, exc_value, tb):
...@@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object): ...@@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object):
hypernetwork.apply_strength() hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment